moved legacy-version of metrics and updated classes for keras3
mbarbetti committed Jun 19, 2024
1 parent 136f179 commit 23b1865
Showing 29 changed files with 222 additions and 109 deletions.
31 changes: 23 additions & 8 deletions src/pidgan/metrics/__init__.py
@@ -1,8 +1,23 @@
-from .Accuracy import Accuracy
-from .BinaryCrossentropy import BinaryCrossentropy
-from .JSDivergence import JSDivergence
-from .KLDivergence import KLDivergence
-from .MeanAbsoluteError import MeanAbsoluteError
-from .MeanSquaredError import MeanSquaredError
-from .RootMeanSquaredError import RootMeanSquaredError
-from .WassersteinDistance import WassersteinDistance
+import keras as k
+
+k_vrs = k.__version__.split(".")[:2]
+k_vrs = float(".".join(k_vrs))
+
+if k_vrs >= 3.0:
+    from .k3.Accuracy import Accuracy
+    from .k3.BinaryCrossentropy import BinaryCrossentropy
+    from .k3.JSDivergence import JSDivergence
+    from .k3.KLDivergence import KLDivergence
+    from .k3.MeanAbsoluteError import MeanAbsoluteError
+    from .k3.MeanSquaredError import MeanSquaredError
+    from .k3.RootMeanSquaredError import RootMeanSquaredError
+    from .k3.WassersteinDistance import WassersteinDistance
+else:
+    from .k2.Accuracy import Accuracy
+    from .k2.BinaryCrossentropy import BinaryCrossentropy
+    from .k2.JSDivergence import JSDivergence
+    from .k2.KLDivergence import KLDivergence
+    from .k2.MeanAbsoluteError import MeanAbsoluteError
+    from .k2.MeanSquaredError import MeanSquaredError
+    from .k2.RootMeanSquaredError import RootMeanSquaredError
+    from .k2.WassersteinDistance import WassersteinDistance
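For orientation: the new `__init__.py` gates its imports on the installed Keras major version, so `from pidgan.metrics import Accuracy` transparently resolves to the k3 or k2 implementation. A minimal sketch of the same idea, assuming only that `keras.__version__` is a dotted version string (hypothetical simplification, not the committed code):

import keras

# Compare the major version as an integer instead of a float
keras_major = int(keras.__version__.split(".")[0])
if keras_major >= 3:
    backend = "k3"  # Keras 3 implementations built on keras.ops
else:
    backend = "k2"  # legacy tf.keras-based implementations
print(backend)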
6 changes: 3 additions & 3 deletions src/pidgan/metrics/{ → k2}/Accuracy.py
@@ -1,12 +1,12 @@
 import tensorflow as tf
-from tensorflow import keras
+import keras

-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric


 class Accuracy(BaseMetric):
     def __init__(self, name="accuracy", dtype=None, threshold=0.5) -> None:
-        super().__init__(name, dtype)
+        super().__init__(name=name, dtype=dtype)
         self._accuracy = keras.metrics.BinaryAccuracy(
             name=name, dtype=dtype, threshold=threshold
         )
4 changes: 2 additions & 2 deletions src/pidgan/metrics/{ → k2}/BaseMetric.py
@@ -1,9 +1,9 @@
-from tensorflow import keras
+import keras


 class BaseMetric(keras.metrics.Metric):
     def __init__(self, name="metric", dtype=None) -> None:
-        super().__init__(name, dtype)
+        super().__init__(name=name, dtype=dtype)
         self._metric_values = self.add_weight(
             name=f"{name}_values", initializer="zeros"
         )
6 changes: 3 additions & 3 deletions src/pidgan/metrics/{ → k2}/BinaryCrossentropy.py
@@ -1,14 +1,14 @@
 import tensorflow as tf
-from tensorflow import keras
+import keras

-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric


 class BinaryCrossentropy(BaseMetric):
     def __init__(
         self, name="bce", dtype=None, from_logits=False, label_smoothing=0.0
     ) -> None:
-        super().__init__(name, dtype)
+        super().__init__(name=name, dtype=dtype)
         self._bce = keras.metrics.BinaryCrossentropy(
             name=name,
             dtype=dtype,
6 changes: 3 additions & 3 deletions src/pidgan/metrics/{ → k2}/JSDivergence.py
@@ -1,12 +1,12 @@
 import tensorflow as tf
-from tensorflow import keras
+import keras

-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric


 class JSDivergence(BaseMetric):
     def __init__(self, name="js_div", dtype=None) -> None:
-        super().__init__(name, dtype)
+        super().__init__(name=name, dtype=dtype)
         self._kl_div = keras.metrics.KLDivergence(name=name, dtype=dtype)

     def update_state(self, y_true, y_pred, sample_weight=None) -> None:
6 changes: 3 additions & 3 deletions src/pidgan/metrics/{ → k2}/KLDivergence.py
@@ -1,11 +1,11 @@
-from tensorflow import keras
+import keras

-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric


 class KLDivergence(BaseMetric):
     def __init__(self, name="kl_div", dtype=None) -> None:
-        super().__init__(name, dtype)
+        super().__init__(name=name, dtype=dtype)
         self._kl_div = keras.metrics.KLDivergence(name=name, dtype=dtype)

     def update_state(self, y_true, y_pred, sample_weight=None) -> None:
8 changes: 4 additions & 4 deletions src/pidgan/metrics/{ → k2}/MeanAbsoluteError.py
@@ -1,11 +1,11 @@
-from tensorflow import keras
+import keras

-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric


 class MeanAbsoluteError(BaseMetric):
-    def __init__(self, name="mae", dtype=None, **kwargs) -> None:
-        super().__init__(name, dtype, **kwargs)
+    def __init__(self, name="mae", dtype=None) -> None:
+        super().__init__(name=name, dtype=dtype)
         self._mae = keras.metrics.MeanAbsoluteError(name=name, dtype=dtype)

     def update_state(self, y_true, y_pred, sample_weight=None) -> None:
8 changes: 4 additions & 4 deletions src/pidgan/metrics/{ → k2}/MeanSquaredError.py
@@ -1,11 +1,11 @@
-from tensorflow import keras
+import keras

-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric


 class MeanSquaredError(BaseMetric):
-    def __init__(self, name="mse", dtype=None, **kwargs) -> None:
-        super().__init__(name, dtype, **kwargs)
+    def __init__(self, name="mse", dtype=None) -> None:
+        super().__init__(name=name, dtype=dtype)
         self._mse = keras.metrics.MeanSquaredError(name=name, dtype=dtype)

     def update_state(self, y_true, y_pred, sample_weight=None) -> None:
8 changes: 4 additions & 4 deletions src/pidgan/metrics/{ → k2}/RootMeanSquaredError.py
@@ -1,11 +1,11 @@
-from tensorflow import keras
+import keras

-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric


 class RootMeanSquaredError(BaseMetric):
-    def __init__(self, name="rmse", dtype=None, **kwargs):
-        super().__init__(name, dtype, **kwargs)
+    def __init__(self, name="rmse", dtype=None):
+        super().__init__(name=name, dtype=dtype)
         self._rmse = keras.metrics.RootMeanSquaredError(name=name, dtype=dtype)

     def update_state(self, y_true, y_pred, sample_weight=None):
4 changes: 2 additions & 2 deletions src/pidgan/metrics/{ → k2}/WassersteinDistance.py
@@ -1,11 +1,11 @@
 import tensorflow as tf

-from pidgan.metrics.BaseMetric import BaseMetric
+from pidgan.metrics.k2.BaseMetric import BaseMetric


 class WassersteinDistance(BaseMetric):
     def __init__(self, name="wass_dist", dtype=None) -> None:
-        super().__init__(name, dtype)
+        super().__init__(name=name, dtype=dtype)

     def update_state(self, y_true, y_pred, sample_weight=None) -> None:
         if sample_weight is not None:
src/pidgan/metrics/k2/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions src/pidgan/metrics/k3/Accuracy.py
@@ -0,0 +1,17 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class Accuracy(BaseMetric):
+    def __init__(self, name="accuracy", dtype=None, threshold=0.5) -> None:
+        super().__init__(name=name, dtype=dtype)
+        self._accuracy = k.metrics.BinaryAccuracy(
+            name=name, dtype=dtype, threshold=threshold
+        )
+
+    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+        state = self._accuracy(
+            k.ops.ones_like(y_pred), y_pred, sample_weight=sample_weight
+        )
+        self._metric_values.assign(state)
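Note that `update_state()` ignores `y_true` and scores `y_pred` against an all-ones reference, i.e. the fraction of discriminator outputs classified as "real" at the given threshold. A usage sketch with hypothetical values (assumes pidgan is installed and importable as in the tests below):

import numpy as np
from pidgan.metrics import Accuracy

y_pred = np.array([0.2, 0.6, 0.7, 0.9])  # discriminator outputs
metric = Accuracy(threshold=0.5)
metric.update_state(np.ones_like(y_pred), y_pred)  # y_true is ignored
print(float(metric.result()))  # 0.75: three of four outputs exceed 0.5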
18 changes: 18 additions & 0 deletions src/pidgan/metrics/k3/BaseMetric.py
@@ -0,0 +1,18 @@
+import keras as k
+
+
+class BaseMetric(k.metrics.Metric):
+    def __init__(self, name="metric", dtype=None) -> None:
+        super().__init__(name=name, dtype=dtype)
+        self._metric_values = self.add_weight(
+            name=f"{name}_values", initializer="zeros"
+        )
+
+    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+        raise NotImplementedError(
+            "Only `BaseMetric` subclasses have the "
+            "`update_state()` method implemented."
+        )
+
+    def result(self):
+        return self._metric_values
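All concrete metrics in this commit follow the same pattern: hold a stateless Keras metric internally, recompute the value in `update_state()`, and store it in the single `_metric_values` weight. A minimal sketch of a custom subclass (hypothetical `MeanError` class for illustration, assuming a Keras 3 environment with pidgan installed):

import keras as k

from pidgan.metrics.k3.BaseMetric import BaseMetric


class MeanError(BaseMetric):
    """Signed mean difference between predictions and targets."""

    def __init__(self, name="mean_error", dtype=None) -> None:
        super().__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
        # overwrite the stored value with the latest batch statistic
        state = k.ops.mean(y_pred - y_true)
        self._metric_values.assign(k.ops.cast(state, self.dtype))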
20 changes: 20 additions & 0 deletions src/pidgan/metrics/k3/BinaryCrossentropy.py
@@ -0,0 +1,20 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class BinaryCrossentropy(BaseMetric):
+    def __init__(
+        self, name="bce", dtype=None, from_logits=False, label_smoothing=0.0
+    ) -> None:
+        super().__init__(name=name, dtype=dtype)
+        self._bce = k.metrics.BinaryCrossentropy(
+            name=name,
+            dtype=dtype,
+            from_logits=from_logits,
+            label_smoothing=label_smoothing,
+        )
+
+    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+        state = self._bce(k.ops.ones_like(y_pred), y_pred, sample_weight=sample_weight)
+        self._metric_values.assign(state)
21 changes: 21 additions & 0 deletions src/pidgan/metrics/k3/JSDivergence.py
@@ -0,0 +1,21 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class JSDivergence(BaseMetric):
+    def __init__(self, name="js_div", dtype=None) -> None:
+        super().__init__(name=name, dtype=dtype)
+        self._kl_div = k.metrics.KLDivergence(name=name, dtype=dtype)
+
+    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+        dtype = self._kl_div(y_true, y_pred).dtype
+        y_true = k.ops.cast(y_true, dtype)
+        y_pred = k.ops.cast(y_pred, dtype)
+
+        state = 0.5 * self._kl_div(
+            y_true, 0.5 * (y_true + y_pred), sample_weight=sample_weight
+        ) + 0.5 * self._kl_div(
+            y_pred, 0.5 * (y_true + y_pred), sample_weight=sample_weight
+        )
+        self._metric_values.assign(state)
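The two KL terms implement the usual symmetrized definition, JS(P, Q) = 0.5 KL(P ‖ M) + 0.5 KL(Q ‖ M) with M = 0.5 (P + Q). A NumPy cross-check sketch (hypothetical helper, not part of pidgan; assumes, like `keras.metrics.KLDivergence`, that inputs are per-sample probability vectors):

import numpy as np

def js_divergence(p, q, eps=1e-12):
    # KL(a || b) summed over the last axis and averaged over samples,
    # mirroring the keras.metrics.KLDivergence convention
    kl = lambda a, b: np.mean(np.sum(a * np.log((a + eps) / (b + eps)), axis=-1))
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.array([[0.7, 0.3], [0.4, 0.6]])
q = np.array([[0.5, 0.5], [0.1, 0.9]])
print(js_divergence(p, q))  # small non-negative scalar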
13 changes: 13 additions & 0 deletions src/pidgan/metrics/k3/KLDivergence.py
@@ -0,0 +1,13 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class KLDivergence(BaseMetric):
+    def __init__(self, name="kl_div", dtype=None) -> None:
+        super().__init__(name=name, dtype=dtype)
+        self._kl_div = k.metrics.KLDivergence(name=name, dtype=dtype)
+
+    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+        state = self._kl_div(y_true, y_pred, sample_weight=sample_weight)
+        self._metric_values.assign(state)
13 changes: 13 additions & 0 deletions src/pidgan/metrics/k3/MeanAbsoluteError.py
@@ -0,0 +1,13 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class MeanAbsoluteError(BaseMetric):
+    def __init__(self, name="mae", dtype=None) -> None:
+        super().__init__(name=name, dtype=dtype)
+        self._mae = k.metrics.MeanAbsoluteError(name=name, dtype=dtype)
+
+    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+        state = self._mae(y_true, y_pred, sample_weight=sample_weight)
+        self._metric_values.assign(state)
13 changes: 13 additions & 0 deletions src/pidgan/metrics/k3/MeanSquaredError.py
@@ -0,0 +1,13 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class MeanSquaredError(BaseMetric):
+    def __init__(self, name="mse", dtype=None) -> None:
+        super().__init__(name=name, dtype=dtype)
+        self._mse = k.metrics.MeanSquaredError(name=name, dtype=dtype)
+
+    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+        state = self._mse(y_true, y_pred, sample_weight=sample_weight)
+        self._metric_values.assign(state)
13 changes: 13 additions & 0 deletions src/pidgan/metrics/k3/RootMeanSquaredError.py
@@ -0,0 +1,13 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class RootMeanSquaredError(BaseMetric):
+    def __init__(self, name="rmse", dtype=None):
+        super().__init__(name=name, dtype=dtype)
+        self._rmse = k.metrics.RootMeanSquaredError(name=name, dtype=dtype)
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        state = self._rmse(y_true, y_pred, sample_weight=sample_weight)
+        self._metric_values.assign(state)
17 changes: 17 additions & 0 deletions src/pidgan/metrics/k3/WassersteinDistance.py
@@ -0,0 +1,17 @@
+import keras as k
+
+from pidgan.metrics.k3.BaseMetric import BaseMetric
+
+
+class WassersteinDistance(BaseMetric):
+    def __init__(self, name="wass_dist", dtype=None) -> None:
+        super().__init__(name=name, dtype=dtype)
+
+    def update_state(self, y_true, y_pred, sample_weight=None) -> None:
+        if sample_weight is not None:
+            state = k.ops.sum(sample_weight * (y_pred - y_true))
+            state /= k.ops.sum(sample_weight)
+        else:
+            state = k.ops.mean(y_pred - y_true)
+        state = k.ops.cast(state, self.dtype)
+        self._metric_values.assign(state)
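As in the k2 version, the "distance" here is the (optionally weighted) mean critic-output difference used in Wasserstein GAN training, so it can be negative. A quick numeric sketch with hypothetical values (assumes pidgan is installed):

import numpy as np
from pidgan.metrics import WassersteinDistance

y_true = np.array([0.1, 0.3, 0.5])  # critic outputs on reference samples
y_pred = np.array([0.2, 0.4, 0.9])  # critic outputs on generated samples
w = np.array([1.0, 2.0, 1.0])

metric = WassersteinDistance()
metric.update_state(y_true, y_pred, sample_weight=w)
# weighted mean difference: sum(w * (y_pred - y_true)) / sum(w) = 0.7 / 4.0
print(float(metric.result()))  # 0.175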
src/pidgan/metrics/k3/__init__.py
Empty file.
11 changes: 3 additions & 8 deletions tests/metrics/test_Accuracy_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
     assert isinstance(metric.name, str)


-def test_metric_use_no_weights(metric):
-    metric.update_state(y_true, y_pred, sample_weight=None)
-    res = metric.result().numpy()
-    assert res
-
-
-def test_metric_use_with_weights(metric):
-    metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+    metric.update_state(y_true, y_pred, sample_weight=sample_weight)
     res = metric.result().numpy()
     assert res
21 changes: 4 additions & 17 deletions tests/metrics/test_BinaryCrossentropy_metric.py
@@ -27,28 +27,15 @@ def test_metric_configuration(metric):


 @pytest.mark.parametrize("from_logits", [False, True])
-def test_metric_use_no_weights(from_logits):
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(from_logits, sample_weight):
     from pidgan.metrics import BinaryCrossentropy

     metric = BinaryCrossentropy(from_logits=from_logits, label_smoothing=0.0)
     if from_logits:
-        metric.update_state(y_true, y_pred_logits, sample_weight=None)
+        metric.update_state(y_true, y_pred_logits, sample_weight=sample_weight)
         res = metric.result().numpy()
     else:
-        metric.update_state(y_true, y_pred, sample_weight=None)
-        res = metric.result().numpy()
-    assert res
-
-
-@pytest.mark.parametrize("from_logits", [False, True])
-def test_metric_use_with_weights(from_logits):
-    from pidgan.metrics import BinaryCrossentropy
-
-    metric = BinaryCrossentropy(from_logits=from_logits, label_smoothing=0.0)
-    if from_logits:
-        metric.update_state(y_true, y_pred_logits, sample_weight=weight)
-        res = metric.result().numpy()
-    else:
-        metric.update_state(y_true, y_pred, sample_weight=weight)
+        metric.update_state(y_true, y_pred, sample_weight=sample_weight)
         res = metric.result().numpy()
     assert res
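Stacked `pytest.mark.parametrize` decorators take the Cartesian product of their arguments, so the rewritten test above runs 2 × 2 = 4 cases instead of the two near-duplicate functions it replaces. A self-contained sketch of the pattern (hypothetical toy test, not from the repository):

import pytest

@pytest.mark.parametrize("from_logits", [False, True])
@pytest.mark.parametrize("sample_weight", [None, 2.0])
def test_combinations(from_logits, sample_weight):
    # pytest generates one test per (sample_weight, from_logits) pair
    assert from_logits in (False, True)
    assert sample_weight in (None, 2.0)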
11 changes: 3 additions & 8 deletions tests/metrics/test_JSDivergence_metric.py
@@ -26,13 +26,8 @@ def test_metric_configuration(metric):
     assert isinstance(metric.name, str)


-def test_metric_use_no_weights(metric):
-    metric.update_state(y_true, y_pred, sample_weight=None)
-    res = metric.result().numpy()
-    assert res
-
-
-def test_metric_use_with_weights(metric):
-    metric.update_state(y_true, y_pred, sample_weight=weight)
+@pytest.mark.parametrize("sample_weight", [None, weight])
+def test_metric_use(metric, sample_weight):
+    metric.update_state(y_true, y_pred, sample_weight=sample_weight)
     res = metric.result().numpy()
     assert res