From 14928921e2f6d5b049d8dcfa07982e9ca351a402 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 4 Aug 2022 20:41:15 +0200 Subject: [PATCH] Add `TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING` (#18469) Co-authored-by: ydshieh --- src/transformers/__init__.py | 2 ++ src/transformers/models/auto/__init__.py | 2 ++ .../data2vec/modeling_tf_data2vec_vision.py | 4 ++-- .../models/segformer/modeling_tf_segformer.py | 16 ++++++++-------- src/transformers/utils/dummy_tf_objects.py | 3 +++ .../segformer/test_modeling_tf_segformer.py | 4 ++++ tests/test_modeling_tf_common.py | 18 ++++++++++++++++-- 7 files changed, 37 insertions(+), 12 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e8cfd47f3d3b37..5e1e95c6291b78 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2088,6 +2088,7 @@ "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "TF_MODEL_FOR_PRETRAINING_MAPPING", "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", @@ -4582,6 +4583,7 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, TF_MODEL_FOR_PRETRAINING_MAPPING, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index b04c2420ef963e..139d4feda336e0 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -111,6 +111,7 @@ "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "TF_MODEL_FOR_PRETRAINING_MAPPING", "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", @@ -253,6 +254,7 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, TF_MODEL_FOR_PRETRAINING_MAPPING, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py index e09cbfb9c42a6e..33e9921cc9a58c 100644 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -1352,8 +1352,8 @@ def masked_loss(real, pred): loss_ = loss_fct(real, pred) mask = tf.cast(mask, dtype=loss_.dtype) loss_ *= mask - - return tf.reduce_sum(loss_) / tf.reduce_sum(mask) + reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask) + return tf.reshape(reduced_masked_loss, (1,)) main_loss = masked_loss(labels, upsampled_logits) auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits) diff --git a/src/transformers/models/segformer/modeling_tf_segformer.py b/src/transformers/models/segformer/modeling_tf_segformer.py index 25350d1c82559a..c2f4b2ff0c7cd8 100644 --- a/src/transformers/models/segformer/modeling_tf_segformer.py +++ b/src/transformers/models/segformer/modeling_tf_segformer.py @@ -201,9 +201,9 @@ def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs): self.dense = tf.keras.layers.Dense(hidden_size, name="dense") self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) return hidden_states @@ -276,13 +276,13 @@ def __init__( self.dense2 = tf.keras.layers.Dense(out_features, name="dense2") self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) - def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor: + def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: hidden_states = self.dense1(hidden_states) hidden_states = self.depthwise_convolution(hidden_states, height, width) hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dense2(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) return hidden_states @@ -749,7 +749,7 @@ def __init__(self, config: SegformerConfig, **kwargs): self.config = config - def call(self, encoder_hidden_states): + def call(self, encoder_hidden_states, training: bool = False): batch_size = shape_list(encoder_hidden_states[-1])[0] all_hidden_states = () @@ -773,9 +773,9 @@ def call(self, encoder_hidden_states): all_hidden_states += (encoder_hidden_state,) hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1)) - hidden_states = self.batch_norm(hidden_states) + hidden_states = self.batch_norm(hidden_states, training=training) hidden_states = self.activation(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) # logits of shape (batch_size, height/4, width/4, num_labels) logits = self.classifier(hidden_states) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 37b58cd8146601..6df601ca646af3 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -279,6 +279,9 @@ def __init__(self, *args, **kwargs): TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None +TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = None + + TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None diff --git a/tests/models/segformer/test_modeling_tf_segformer.py b/tests/models/segformer/test_modeling_tf_segformer.py index 6cc2c77fe935fc..d6a73e22192c3b 100644 --- a/tests/models/segformer/test_modeling_tf_segformer.py +++ b/tests/models/segformer/test_modeling_tf_segformer.py @@ -27,6 +27,7 @@ if is_tf_available(): + import numpy as np import tensorflow as tf from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, TFSegformerModel @@ -336,6 +337,9 @@ def recursive_check(tuple_object, dict_object): def test_dataset_conversion(self): super().test_dataset_conversion() + def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1): + self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol)) + @unittest.skipIf( not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index d63b1b32733e89..15855e6a1f40e6 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -62,11 +62,13 @@ from transformers import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, TF_MODEL_FOR_PRETRAINING_MAPPING, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, @@ -170,6 +172,15 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> d inputs_dict["labels"] = tf.zeros( (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32 ) + elif model_class in get_values(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING): + num_patches = self.model_tester.image_size // self.model_tester.patch_size + inputs_dict["bool_masked_pos"] = tf.zeros( + (self.model_tester.batch_size, num_patches**2), dtype=tf.int32 + ) + elif model_class in get_values(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING): + batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape + inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, height, width), dtype=tf.int32) + return inputs_dict def test_initialization(self): @@ -1389,6 +1400,9 @@ def test_loss_computation(self): self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1]) + def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3): + self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol)) + def test_keras_fit(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: @@ -1468,7 +1482,7 @@ def test_keras_fit(self): val_loss2 = history2.history["val_loss"][0] self.assertTrue(not isnan(val_loss2)) accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")} - self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3)) + self.check_keras_fit_results(val_loss1, val_loss2) self.assertEqual(history1.history.keys(), history2.history.keys()) for key in history1.history.keys(): if not key.startswith("val_"): @@ -1494,7 +1508,7 @@ def test_keras_fit(self): val_loss3 = history3.history["val_loss"][0] self.assertTrue(not isnan(val_loss3)) accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")} - self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3)) + self.check_keras_fit_results(val_loss1, val_loss3) self.assertEqual(history1.history.keys(), history3.history.keys()) if metrics: self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")