Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added visualization code using saliency maps #223

Merged
merged 1 commit into from
Mar 11, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 45 additions & 26 deletions nolearn/lasagne/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from lasagne.layers import get_output
from lasagne.layers import get_output_shape
from lasagne.objectives import binary_crossentropy
import matplotlib.pyplot as plt
import numpy as np
import theano
Expand Down Expand Up @@ -176,6 +177,37 @@ def occlusion_heatmap(net, x, target, square_length=7):
return heat_array


def _plot_heat_map(net, X, figsize, get_heat_image):
if (X.ndim != 4):
raise ValueError("This function requires the input data to be of "
"shape (b, c, x, y), instead got {}".format(X.shape))

num_images = X.shape[0]
if figsize[1] is None:
figsize = (figsize[0], num_images * figsize[0] / 3)
figs, axes = plt.subplots(num_images, 3, figsize=figsize)

for ax in axes.flatten():
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')

for n in range(num_images):
heat_img = get_heat_image(net, X[n:n + 1, :, :, :], n)

ax = axes if num_images == 1 else axes[n]
img = X[n, :, :, :].mean(0)
ax[0].imshow(-img, interpolation='nearest', cmap='gray')
ax[0].set_title('image')
ax[1].imshow(-heat_img, interpolation='nearest', cmap='Reds')
ax[1].set_title('critical parts')
ax[2].imshow(-img, interpolation='nearest', cmap='gray')
ax[2].imshow(-heat_img, interpolation='nearest', cmap='Reds',
alpha=0.6)
ax[2].set_title('super-imposed')
return plt


def plot_occlusion(net, X, target, square_length=7, figsize=(9, None)):
"""Plot which parts of an image are particularly import for the
net to classify the image correctly.
Expand Down Expand Up @@ -210,33 +242,20 @@ def plot_occlusion(net, X, target, square_length=7, figsize=(9, None)):
and both images super-imposed.

"""
if (X.ndim != 4):
raise ValueError("This function requires the input data to be of "
"shape (b, c, x, y), instead got {}".format(X.shape))
return _plot_heat_map(net, X, figsize, lambda net, X, n: occlusion_heatmap(net, X, target[n], square_length))

num_images = X.shape[0]
if figsize[1] is None:
figsize = (figsize[0], num_images * figsize[0] / 3)
figs, axes = plt.subplots(num_images, 3, figsize=figsize)

for ax in axes.flatten():
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')
def saliency_map(input, output, pred, X):
score = -binary_crossentropy(output[:, pred], np.array([1])).sum()
return np.abs(T.grad(score, input).eval({input: X}))

for n in range(num_images):
heat_img = occlusion_heatmap(
net, X[n:n + 1, :, :, :], target[n], square_length
)

ax = axes if num_images == 1 else axes[n]
img = X[n, :, :, :].mean(0)
ax[0].imshow(-img, interpolation='nearest', cmap='gray')
ax[0].set_title('image')
ax[1].imshow(-heat_img, interpolation='nearest', cmap='Reds')
ax[1].set_title('critical parts')
ax[2].imshow(-img, interpolation='nearest', cmap='gray')
ax[2].imshow(-heat_img, interpolation='nearest', cmap='Reds',
alpha=0.6)
ax[2].set_title('super-imposed')
return plt
def saliency_map_net(net, X):
input = net.layers_[0].input_var
output = get_output(net.layers_[-1])
pred = output.eval({input: X}).argmax(axis=1)
return saliency_map(input, output, pred, X)[0].transpose(1, 2, 0).squeeze()


def plot_saliency(net, X, figsize=(9, None)):
return _plot_heat_map(net, X, figsize, lambda net, X, n: -saliency_map_net(net, X))