Skip to content

Commit

Permalink
Add some better examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
lartpang authored Jun 22, 2022
1 parent 8957131 commit 92bf0ec
Showing 1 changed file with 99 additions and 12 deletions.
111 changes: 99 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,55 +104,142 @@ print(ssim_score_0.shape, ssim_score_1.shape)

## As A Loss

![prediction](https://user-images.githubusercontent.com/26847524/174814849-f80ec67c-5397-4ce6-bf4e-8b0aa568ed6f.png)
As you can see from the respective thresholds of the two cases below, it is easier to optimize towards MSSIM=1 than MSSIM=-1.

### Optimize towards MSSIM=1

![prediction](https://user-images.githubusercontent.com/26847524/174930091-9d7f7505-1752-423a-b7c3-d4dbfeb8d336.png)

```python
import matplotlib.pyplot as plt
import torch
from pytorch_ssim import SSIM
from skimage import data
from torch.optim import Adam
from torch import optim


original_image = data.camera() / 255
original_image = data.moon() / 255
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()
predicted_image = torch.rand_like(
predicted_image = torch.zeros_like(
target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True
)
initial_image = predicted_image.clone()

ssim = SSIM().cuda()
initial_ssim_value = ssim(predicted_image, target_image)
print(f"Initial ssim: {initial_ssim_value.item():.4f}")
ssim_value = initial_ssim_value

optimizer = Adam([predicted_image], lr=0.01)
ssim_value = initial_ssim_value
optimizer = optim.Adam([predicted_image], lr=0.01)
loss_curves = []
while ssim_value < 0.95:
while ssim_value < 0.999:
ssim_out = 1 - ssim(predicted_image, target_image)
loss_curves.append(ssim_out.item())
ssim_value = 1 - ssim_out.item()
print(ssim_value)
ssim_out.backward()
optimizer.step()
optimizer.zero_grad()

fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(8, 2))
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
ax = axes.ravel()

ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[0].set_title("Original Image")

ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.4f}")
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}")
ax[1].set_title("Initial Image")

ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
ax[2].set_xlabel(f"SSIM: {ssim_value:.4f}")
ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}")
ax[2].set_title("Predicted Image")

ax[3].plot(loss_curves)
ax[3].set_title("SSIM Loss Curve")

ax[4].set_title("Original Image")
ax[4].hist(original_image.ravel(), bins=256)
ax[4].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[4].set_xlabel("Pixel Intensity")

ax[5].set_title("Initial Image")
ax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
ax[5].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[5].set_xlabel("Pixel Intensity")

ax[6].set_title("Predicted Image")
ax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
ax[6].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[6].set_xlabel("Pixel Intensity")

plt.tight_layout()
plt.savefig("prediction.png")
```

### Optimize towards MSSIM=-1

![prediction](https://user-images.githubusercontent.com/26847524/174929574-5332cab2-104f-4aab-a4e5-35e7635a793f.png)

```python
import matplotlib.pyplot as plt
import torch
from pytorch_ssim import SSIM
from skimage import data
from torch import optim

original_image = data.moon() / 255
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()
predicted_image = torch.zeros_like(
target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True
)
initial_image = predicted_image.clone()

ssim = SSIM(L=original_image.max() - original_image.min()).cuda()
initial_ssim_value = ssim(predicted_image, target_image)

ssim_value = initial_ssim_value
optimizer = optim.Adam([predicted_image], lr=0.01)
loss_curves = []
while ssim_value > -0.94:
ssim_out = ssim(predicted_image, target_image)
loss_curves.append(ssim_out.item())
ssim_value = ssim_out.item()
print(ssim_value)
ssim_out.backward()
optimizer.step()
optimizer.zero_grad()

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
ax = axes.ravel()

ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[0].set_title("Original Image")

ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}")
ax[1].set_title("Initial Image")

ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}")
ax[2].set_title("Predicted Image")

ax[3].plot(loss_curves)
ax[3].set_title("SSIM Loss Curve")

ax[4].set_title("Original Image")
ax[4].hist(original_image.ravel(), bins=256)
ax[4].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[4].set_xlabel("Pixel Intensity")

ax[5].set_title("Initial Image")
ax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
ax[5].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[5].set_xlabel("Pixel Intensity")

ax[6].set_title("Predicted Image")
ax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
ax[6].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[6].set_xlabel("Pixel Intensity")

plt.tight_layout()
plt.savefig("prediction.png")
```
Expand Down

0 comments on commit 92bf0ec

Please sign in to comment.