Skip to content

Commit

Permalink
Rename deprecated method (#2533)
Browse files Browse the repository at this point in the history
* Rename deprecated method
  • Loading branch information
RyanGoslingsBugle authored Jul 29, 2021
1 parent e9846e8 commit f30df43
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 23 deletions.
2 changes: 1 addition & 1 deletion tensorflow_addons/metrics/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ must:
Any PR which adds a new metric must ensure that:

1. It inherits from the `tf.keras.metrics.Metric` class.
2. Overrides the `update_state()`, `result()`, and `reset_states()` methods.
2. Overrides the `update_state()`, `result()`, and `reset_state()` methods.
3. Implements a `get_config()` method.

The implementation must also ensure that the following cases are well tested and supported:
Expand Down
8 changes: 7 additions & 1 deletion tensorflow_addons/metrics/cohens_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,17 @@ def get_config(self):
base_config = super().get_config()
return {**base_config, **config}

def reset_states(self):
def reset_state(self):
"""Resets all of the metric state variables."""

for v in self.variables:
K.set_value(
v,
np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype),
)

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
# only implement `reset_state`.
# Required in Tensorflow < 2.5.0
return self.reset_state()
8 changes: 7 additions & 1 deletion tensorflow_addons/metrics/f_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,16 @@ def get_config(self):
base_config = super().get_config()
return {**base_config, **config}

def reset_states(self):
def reset_state(self):
reset_value = tf.zeros(self.init_shape, dtype=self.dtype)
K.batch_set_value([(v, reset_value) for v in self.variables])

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
# only implement `reset_state`.
# Required in Tensorflow < 2.5.0
return self.reset_state()


@tf.keras.utils.register_keras_serializable(package="Addons")
class F1Score(FBetaScore):
Expand Down
8 changes: 7 additions & 1 deletion tensorflow_addons/metrics/geometric_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,11 @@ def result(self) -> tf.Tensor:
ret = tf.math.exp(self.total / self.count)
return tf.cast(ret, dtype=self.dtype)

def reset_states(self) -> None:
def reset_state(self) -> None:
K.batch_set_value([(v, 0) for v in self.variables])

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
# only implement `reset_state`.
# Required in Tensorflow < 2.5.0
return self.reset_state()
10 changes: 8 additions & 2 deletions tensorflow_addons/metrics/kendalls_tau.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
self.preds_max = preds_max
self.actual_cutpoints = actual_cutpoints
self.preds_cutpoints = preds_cutpoints
self.reset_states()
self.reset_state()

def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates ranks.
Expand Down Expand Up @@ -177,7 +177,7 @@ def get_config(self):
base_config = super().get_config()
return {**base_config, **config}

def reset_states(self):
def reset_state(self):
"""Resets all of the metric state variables."""
self.actual_cuts = tf.linspace(
tf.cast(self.actual_min, tf.float32),
Expand All @@ -201,3 +201,9 @@ def reset_states(self):
tf.zeros((0, 1), dtype=tf.int64), [], [self.preds_cutpoints]
)
self.n = 0

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
# only implement `reset_state`.
# Required in Tensorflow < 2.5.0
return self.reset_state()
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,17 @@ def get_config(self):
base_config = super().get_config()
return {**base_config, **config}

def reset_states(self):
def reset_state(self):
"""Resets all of the metric state variables."""

for v in self.variables:
K.set_value(
v,
np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype),
)

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
# only implement `reset_state`.
# Required in Tensorflow < 2.5.0
return self.reset_state()
8 changes: 7 additions & 1 deletion tensorflow_addons/metrics/multilabel_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ def get_config(self):
base_config = super().get_config()
return {**base_config, **config}

def reset_states(self):
def reset_state(self):
reset_value = np.zeros(self.num_classes, dtype=np.int32)
K.batch_set_value([(v, reset_value) for v in self.variables])

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
# only implement `reset_state`.
# Required in Tensorflow < 2.5.0
return self.reset_state()
8 changes: 7 additions & 1 deletion tensorflow_addons/metrics/r_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,16 @@ def result(self) -> tf.Tensor:

return r2_score

def reset_states(self) -> None:
def reset_state(self) -> None:
# The state of the metric will be reset at the start of each epoch.
K.batch_set_value([(v, np.zeros(v.shape)) for v in self.variables])

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
# only implement `reset_state`.
# Required in Tensorflow < 2.5.0
return self.reset_state()

def get_config(self):
config = {
"y_shape": self.y_shape,
Expand Down
8 changes: 4 additions & 4 deletions tensorflow_addons/metrics/tests/cohens_kappa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def update_obj_states(obj1, obj2, obj3, actuals, preds, weights):


def reset_obj_states(obj1, obj2, obj3):
obj1.reset_states()
obj2.reset_states()
obj3.reset_states()
obj1.reset_state()
obj2.reset_state()
obj3.reset_state()


def check_results(objs, values):
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_kappa_with_sample_weights():
check_results([kp_obj1, kp_obj2, kp_obj3], [-0.25473321, -0.38992332, -0.60695344])


def test_kappa_reset_states():
def test_kappa_reset_state():
# Initialize
kp_obj1, kp_obj2, kp_obj3 = initialize_vars()

Expand Down
4 changes: 2 additions & 2 deletions tensorflow_addons/metrics/tests/geometric_mean_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def test_call_gmean(values, expected):
np.testing.assert_equal(len(values), count)


def test_reset_states():
def test_reset_state():
obj = GeometricMean()
obj.update_state([1, 2, 3, 4, 5])
obj.reset_states()
obj.reset_state()
assert obj.total.numpy() == 0.0
assert obj.count.numpy() == 0.0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_keras_model():
model.fit(data, labels, epochs=1, batch_size=32, verbose=0)


def test_reset_states_graph():
def test_reset_state_graph():
gt_label = tf.constant(
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32
)
Expand All @@ -120,9 +120,9 @@ def test_reset_states_graph():
mcc.update_state(gt_label, preds)

@tf.function
def reset_states():
mcc.reset_states()
def reset_state():
mcc.reset_state()

reset_states()
reset_state()
# Check results
check_results(mcc, [0])
8 changes: 4 additions & 4 deletions tensorflow_addons/metrics/tests/r_square_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def update_obj_states(obj, actuals, preds, sample_weight=None):


@tf.function
def reset_obj_states(obj):
obj.reset_states()
def reset_obj_state(obj):
obj.reset_state()


def check_results(obj, value):
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_r2_random_score():


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_r2_reset_states():
def test_r2_reset_state():
actuals = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
preds = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
actuals = tf.cast(actuals, dtype=tf.float32)
Expand All @@ -130,7 +130,7 @@ def test_r2_reset_states():
# Update
update_obj_states(r2_obj, actuals, preds)
# Reset
reset_obj_states(r2_obj)
reset_obj_state(r2_obj)
# Check variables
check_variables(r2_obj, 0.0)

Expand Down

0 comments on commit f30df43

Please sign in to comment.