diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index d34043bdd2..1d68e2243c 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -65,7 +65,9 @@ def call(self, x, training=None): shortcut, residual = x # Random bernoulli variable indicating whether the branch should be kept or not or not - b_l = tf.keras.backend.random_bernoulli([], p=self.survival_probability) + b_l = tf.keras.backend.random_bernoulli( + [], p=self.survival_probability, dtype=self._compute_dtype_object + ) def _call_train(): return shortcut + b_l * residual diff --git a/tensorflow_addons/layers/tests/stochastic_depth_test.py b/tensorflow_addons/layers/tests/stochastic_depth_test.py index 1122016f57..31bce01de9 100644 --- a/tensorflow_addons/layers/tests/stochastic_depth_test.py +++ b/tensorflow_addons/layers/tests/stochastic_depth_test.py @@ -47,7 +47,9 @@ def test_with_mixed_precision_policy(): residual = np.asarray([[0.2, 0.4, 0.5]]) output = StochasticDepth()([shortcut, residual]) + assert output.dtype == policy.compute_dtype + output = StochasticDepth()([shortcut, residual], training=True) assert output.dtype == policy.compute_dtype