Skip to content

Commit

Permalink
Add PSNR API (#19616)
Browse files Browse the repository at this point in the history
* PSNR

* Fix
  • Loading branch information
IMvision12 authored Apr 25, 2024
1 parent 6524242 commit 74df926
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from keras.src.ops.nn import multi_hot
from keras.src.ops.nn import normalize
from keras.src.ops.nn import one_hot
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import selu
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from keras.src.ops.nn import multi_hot
from keras.src.ops.nn import normalize
from keras.src.ops.nn import one_hot
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import selu
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from keras.src.ops.nn import multi_hot
from keras.src.ops.nn import normalize
from keras.src.ops.nn import one_hot
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import selu
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from keras.src.ops.nn import multi_hot
from keras.src.ops.nn import normalize
from keras.src.ops.nn import one_hot
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import selu
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,3 +926,16 @@ def ctc_decode(
f"Invalid strategy {strategy}. Supported values are "
"'greedy' and 'beam_search'."
)


def psnr(x1, x2, max_val):
if x1.shape != x2.shape:
raise ValueError(
f"Input shapes {x1.shape} and {x2.shape} must "
"match for PSNR calculation. "
)

max_val = convert_to_tensor(max_val, dtype=x2.dtype)
mse = jnp.mean(jnp.square(x1 - x2))
psnr = 20 * jnp.log10(max_val) - 10 * jnp.log10(mse)
return psnr
13 changes: 13 additions & 0 deletions keras/src/backend/numpy/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,3 +967,16 @@ def ctc_decode(
f"Invalid strategy {strategy}. Supported values are "
"'greedy' and 'beam_search'."
)


def psnr(x1, x2, max_val):
if x1.shape != x2.shape:
raise ValueError(
f"Input shapes {x1.shape} and {x2.shape} must "
"match for PSNR calculation. "
)

max_val = convert_to_tensor(max_val, dtype=x2.dtype)
mse = np.mean(np.square(x1 - x2))
psnr = 20 * np.log10(max_val) - 10 * np.log10(mse)
return psnr
15 changes: 15 additions & 0 deletions keras/src/backend/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,3 +846,18 @@ def ctc_decode(
decoded_dense = tf.stack(decoded_dense, axis=0)
decoded_dense = tf.cast(decoded_dense, "int32")
return decoded_dense, scores


def psnr(x1, x2, max_val):
from keras.src.backend.tensorflow.numpy import log10

if x1.shape != x2.shape:
raise ValueError(
f"Input shapes {x1.shape} and {x2.shape} must "
"match for PSNR calculation. "
)

max_val = convert_to_tensor(max_val, dtype=x2.dtype)
mse = tf.reduce_mean(tf.square(x1 - x2))
psnr = 20 * log10(max_val) - 10 * log10(mse)
return psnr
17 changes: 17 additions & 0 deletions keras/src/backend/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,20 @@ def ctc_decode(
f"Invalid strategy {strategy}. Supported values are "
"'greedy' and 'beam_search'."
)


def psnr(x1, x2, max_val):
if x1.shape != x2.shape:
raise ValueError(
f"Input shapes {x1.shape} and {x2.shape} must "
"match for PSNR calculation. "
)

x1, x2 = (
convert_to_tensor(x1),
convert_to_tensor(x2),
)
max_val = convert_to_tensor(max_val, dtype=x1.dtype)
mse = torch.mean((x1 - x2) ** 2)
psnr = 20 * torch.log10(max_val) - 10 * torch.log10(mse)
return psnr
74 changes: 74 additions & 0 deletions keras/src/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,3 +2042,77 @@ def _normalize(x, axis=-1, order=2):
norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True)
denom = backend.numpy.maximum(norm, epsilon)
return backend.numpy.divide(x, denom)


class PSNR(Operation):
def __init__(
self,
max_val,
):
super().__init__()
self.max_val = max_val

def call(self, x1, x2):
return backend.nn.psnr(
x1=x1,
x2=x2,
max_val=self.max_val,
)

def compute_output_spec(self, x1, x2):
if len(x1.shape) != len(x2.shape):
raise ValueError("Inputs must have the same rank")

return KerasTensor(shape=())


@keras_export(
[
"keras.ops.psnr",
"keras.ops.nn.psnr",
]
)
def psnr(
x1,
x2,
max_val,
):
"""Peak Signal-to-Noise Ratio (PSNR) calculation.
This function calculates the Peak Signal-to-Noise Ratio between two signals,
`x1` and `x2`. PSNR is a measure of the quality of a reconstructed signal.
The higher the PSNR, the closer the reconstructed signal is to the original
signal.
Args:
x1: The first input signal.
x2: The second input signal. Must have the same shape as `x1`.
max_val: The maximum possible value in the signals.
Returns:
float: The PSNR value between `x1` and `x2`.
Examples:
>>> import numpy as np
>>> from keras import ops
>>> x = np.random.random((2, 4, 4, 3))
>>> y = np.random.random((2, 4, 4, 3))
>>> max_val = 1.0
>>> psnr_value = ops.nn.psnr(x, y, max_val)
>>> psnr_value
20.0
"""
if any_symbolic_tensors(
(
x1,
x2,
)
):
return PSNR(
max_val,
).symbolic_call(x1, x2)
return backend.nn.psnr(
x1,
x2,
max_val,
)
31 changes: 31 additions & 0 deletions keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,12 @@ def test_normalize(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.normalize(x).shape, (None, 2, 3))

def test_psnr(self):
x1 = KerasTensor([None, 2, 3])
x2 = KerasTensor([None, 5, 6])
out = knn.psnr(x1, x2, max_val=224)
self.assertEqual(out.shape, ())


class NNOpsStaticShapeTest(testing.TestCase):
def test_relu(self):
Expand Down Expand Up @@ -1114,6 +1120,12 @@ def test_normalize(self):
x = KerasTensor([1, 2, 3])
self.assertEqual(knn.normalize(x).shape, (1, 2, 3))

def test_psnr(self):
x1 = KerasTensor([1, 2, 3])
x2 = KerasTensor([5, 6, 7])
out = knn.psnr(x1, x2, max_val=224)
self.assertEqual(out.shape, ())


class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
def test_relu(self):
Expand Down Expand Up @@ -2032,6 +2044,25 @@ def test_normalize(self):
],
)

def test_psnr(self):
x1 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
x2 = np.array([[0.2, 0.2, 0.3], [0.4, 0.6, 0.6]])
max_val = 1.0
expected_psnr_1 = 20 * np.log10(max_val) - 10 * np.log10(
np.mean(np.square(x1 - x2))
)
psnr_1 = knn.psnr(x1, x2, max_val)
self.assertAlmostEqual(psnr_1, expected_psnr_1)

x3 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
x4 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
max_val = 1.0
expected_psnr_2 = 20 * np.log10(max_val) - 10 * np.log10(
np.mean(np.square(x3 - x4))
)
psnr_2 = knn.psnr(x3, x4, max_val)
self.assertAlmostEqual(psnr_2, expected_psnr_2)


class NNOpsDtypeTest(testing.TestCase, parameterized.TestCase):
"""Test the dtype to verify that the behavior matches JAX."""
Expand Down

0 comments on commit 74df926

Please sign in to comment.