-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
moved legacy-version of metrics and updated classes for keras3
- Loading branch information
Showing
29 changed files
with
222 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,23 @@ | ||
from .Accuracy import Accuracy | ||
from .BinaryCrossentropy import BinaryCrossentropy | ||
from .JSDivergence import JSDivergence | ||
from .KLDivergence import KLDivergence | ||
from .MeanAbsoluteError import MeanAbsoluteError | ||
from .MeanSquaredError import MeanSquaredError | ||
from .RootMeanSquaredError import RootMeanSquaredError | ||
from .WassersteinDistance import WassersteinDistance | ||
import keras as k | ||
|
||
k_vrs = k.__version__.split(".")[:2] | ||
k_vrs = float(".".join([n for n in k_vrs])) | ||
|
||
if k_vrs >= 3.0: | ||
from .k3.Accuracy import Accuracy | ||
from .k3.BinaryCrossentropy import BinaryCrossentropy | ||
from .k3.JSDivergence import JSDivergence | ||
from .k3.KLDivergence import KLDivergence | ||
from .k3.MeanAbsoluteError import MeanAbsoluteError | ||
from .k3.MeanSquaredError import MeanSquaredError | ||
from .k3.RootMeanSquaredError import RootMeanSquaredError | ||
from .k3.WassersteinDistance import WassersteinDistance | ||
else: | ||
from .k2.Accuracy import Accuracy | ||
from .k2.BinaryCrossentropy import BinaryCrossentropy | ||
from .k2.JSDivergence import JSDivergence | ||
from .k2.KLDivergence import KLDivergence | ||
from .k2.MeanAbsoluteError import MeanAbsoluteError | ||
from .k2.MeanSquaredError import MeanSquaredError | ||
from .k2.RootMeanSquaredError import RootMeanSquaredError | ||
from .k2.WassersteinDistance import WassersteinDistance |
6 changes: 3 additions & 3 deletions
6
src/pidgan/metrics/Accuracy.py → src/pidgan/metrics/k2/Accuracy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 2 additions & 2 deletions
4
src/pidgan/metrics/BaseMetric.py → src/pidgan/metrics/k2/BaseMetric.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 3 additions & 3 deletions
6
src/pidgan/metrics/BinaryCrossentropy.py → src/pidgan/metrics/k2/BinaryCrossentropy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 3 additions & 3 deletions
6
src/pidgan/metrics/JSDivergence.py → src/pidgan/metrics/k2/JSDivergence.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 3 additions & 3 deletions
6
src/pidgan/metrics/KLDivergence.py → src/pidgan/metrics/k2/KLDivergence.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 4 additions & 4 deletions
8
src/pidgan/metrics/MeanAbsoluteError.py → src/pidgan/metrics/k2/MeanAbsoluteError.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 4 additions & 4 deletions
8
src/pidgan/metrics/MeanSquaredError.py → src/pidgan/metrics/k2/MeanSquaredError.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 4 additions & 4 deletions
8
src/pidgan/metrics/RootMeanSquaredError.py → ...pidgan/metrics/k2/RootMeanSquaredError.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 2 additions & 2 deletions
4
src/pidgan/metrics/WassersteinDistance.py → src/pidgan/metrics/k2/WassersteinDistance.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import keras as k | ||
|
||
from pidgan.metrics.k3.BaseMetric import BaseMetric | ||
|
||
|
||
class Accuracy(BaseMetric): | ||
def __init__(self, name="accuracy", dtype=None, threshold=0.5) -> None: | ||
super().__init__(name=name, dtype=dtype) | ||
self._accuracy = k.metrics.BinaryAccuracy( | ||
name=name, dtype=dtype, threshold=threshold | ||
) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None) -> None: | ||
state = self._accuracy( | ||
k.ops.ones_like(y_pred), y_pred, sample_weight=sample_weight | ||
) | ||
self._metric_values.assign(state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import keras as k | ||
|
||
|
||
class BaseMetric(k.metrics.Metric): | ||
def __init__(self, name="metric", dtype=None) -> None: | ||
super().__init__(name=name, dtype=dtype) | ||
self._metric_values = self.add_weight( | ||
name=f"{name}_values", initializer="zeros" | ||
) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None) -> None: | ||
raise NotImplementedError( | ||
"Only `BaseMetric` subclasses have the " | ||
"`update_state()` method implemented." | ||
) | ||
|
||
def result(self): | ||
return self._metric_values |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import keras as k | ||
|
||
from pidgan.metrics.k3.BaseMetric import BaseMetric | ||
|
||
|
||
class BinaryCrossentropy(BaseMetric): | ||
def __init__( | ||
self, name="bce", dtype=None, from_logits=False, label_smoothing=0.0 | ||
) -> None: | ||
super().__init__(name=name, dtype=dtype) | ||
self._bce = k.metrics.BinaryCrossentropy( | ||
name=name, | ||
dtype=dtype, | ||
from_logits=from_logits, | ||
label_smoothing=label_smoothing, | ||
) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None) -> None: | ||
state = self._bce(k.ops.ones_like(y_pred), y_pred, sample_weight=sample_weight) | ||
self._metric_values.assign(state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import keras as k | ||
|
||
from pidgan.metrics.k3.BaseMetric import BaseMetric | ||
|
||
|
||
class JSDivergence(BaseMetric): | ||
def __init__(self, name="js_div", dtype=None) -> None: | ||
super().__init__(name=name, dtype=dtype) | ||
self._kl_div = k.metrics.KLDivergence(name=name, dtype=dtype) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None) -> None: | ||
dtype = self._kl_div(y_true, y_pred).dtype | ||
y_true = k.ops.cast(y_true, dtype) | ||
y_pred = k.ops.cast(y_pred, dtype) | ||
|
||
state = 0.5 * self._kl_div( | ||
y_true, 0.5 * (y_true + y_pred), sample_weight=sample_weight | ||
) + 0.5 * self._kl_div( | ||
y_pred, 0.5 * (y_true + y_pred), sample_weight=sample_weight | ||
) | ||
self._metric_values.assign(state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import keras as k | ||
|
||
from pidgan.metrics.k3.BaseMetric import BaseMetric | ||
|
||
|
||
class KLDivergence(BaseMetric): | ||
def __init__(self, name="kl_div", dtype=None) -> None: | ||
super().__init__(name=name, dtype=dtype) | ||
self._kl_div = k.metrics.KLDivergence(name=name, dtype=dtype) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None) -> None: | ||
state = self._kl_div(y_true, y_pred, sample_weight=sample_weight) | ||
self._metric_values.assign(state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import keras as k | ||
|
||
from pidgan.metrics.k3.BaseMetric import BaseMetric | ||
|
||
|
||
class MeanAbsoluteError(BaseMetric): | ||
def __init__(self, name="mae", dtype=None) -> None: | ||
super().__init__(name=name, dtype=dtype) | ||
self._mae = k.metrics.MeanAbsoluteError(name=name, dtype=dtype) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None) -> None: | ||
state = self._mae(y_true, y_pred, sample_weight=sample_weight) | ||
self._metric_values.assign(state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import keras as k | ||
|
||
from pidgan.metrics.k3.BaseMetric import BaseMetric | ||
|
||
|
||
class MeanSquaredError(BaseMetric): | ||
def __init__(self, name="mse", dtype=None) -> None: | ||
super().__init__(name=name, dtype=dtype) | ||
self._mse = k.metrics.MeanSquaredError(name=name, dtype=dtype) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None) -> None: | ||
state = self._mse(y_true, y_pred, sample_weight=sample_weight) | ||
self._metric_values.assign(state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import keras as k | ||
|
||
from pidgan.metrics.k3.BaseMetric import BaseMetric | ||
|
||
|
||
class RootMeanSquaredError(BaseMetric): | ||
def __init__(self, name="rmse", dtype=None): | ||
super().__init__(name=name, dtype=dtype) | ||
self._rmse = k.metrics.RootMeanSquaredError(name=name, dtype=dtype) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None): | ||
state = self._rmse(y_true, y_pred, sample_weight=sample_weight) | ||
self._metric_values.assign(state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import keras as k | ||
|
||
from pidgan.metrics.k3.BaseMetric import BaseMetric | ||
|
||
|
||
class WassersteinDistance(BaseMetric): | ||
def __init__(self, name="wass_dist", dtype=None) -> None: | ||
super().__init__(name=name, dtype=dtype) | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None) -> None: | ||
if sample_weight is not None: | ||
state = k.ops.sum(sample_weight * (y_pred - y_true)) | ||
state /= k.ops.sum(sample_weight) | ||
else: | ||
state = k.ops.mean(y_pred - y_true) | ||
state = k.ops.cast(state, self.dtype) | ||
print("debug:", self.dtype) | ||
self._metric_values.assign(state) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.