Skip to content

Commit

Permalink
Merge pull request #223 from cancan101/patch-1
Browse files Browse the repository at this point in the history
Added visualization code using saliency maps
  • Loading branch information
dnouri committed Mar 11, 2016
2 parents bf4f1e7 + f35b3aa commit f62353a
Showing 1 changed file with 45 additions and 26 deletions.
71 changes: 45 additions & 26 deletions nolearn/lasagne/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from lasagne.layers import get_output
from lasagne.layers import get_output_shape
from lasagne.objectives import binary_crossentropy
import matplotlib.pyplot as plt
import numpy as np
import theano
Expand Down Expand Up @@ -176,6 +177,37 @@ def occlusion_heatmap(net, x, target, square_length=7):
return heat_array


def _plot_heat_map(net, X, figsize, get_heat_image):
if (X.ndim != 4):
raise ValueError("This function requires the input data to be of "
"shape (b, c, x, y), instead got {}".format(X.shape))

num_images = X.shape[0]
if figsize[1] is None:
figsize = (figsize[0], num_images * figsize[0] / 3)
figs, axes = plt.subplots(num_images, 3, figsize=figsize)

for ax in axes.flatten():
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')

for n in range(num_images):
heat_img = get_heat_image(net, X[n:n + 1, :, :, :], n)

ax = axes if num_images == 1 else axes[n]
img = X[n, :, :, :].mean(0)
ax[0].imshow(-img, interpolation='nearest', cmap='gray')
ax[0].set_title('image')
ax[1].imshow(-heat_img, interpolation='nearest', cmap='Reds')
ax[1].set_title('critical parts')
ax[2].imshow(-img, interpolation='nearest', cmap='gray')
ax[2].imshow(-heat_img, interpolation='nearest', cmap='Reds',
alpha=0.6)
ax[2].set_title('super-imposed')
return plt


def plot_occlusion(net, X, target, square_length=7, figsize=(9, None)):
"""Plot which parts of an image are particularly import for the
net to classify the image correctly.
Expand Down Expand Up @@ -210,33 +242,20 @@ def plot_occlusion(net, X, target, square_length=7, figsize=(9, None)):
and both images super-imposed.
"""
if (X.ndim != 4):
raise ValueError("This function requires the input data to be of "
"shape (b, c, x, y), instead got {}".format(X.shape))
return _plot_heat_map(net, X, figsize, lambda net, X, n: occlusion_heatmap(net, X, target[n], square_length))

num_images = X.shape[0]
if figsize[1] is None:
figsize = (figsize[0], num_images * figsize[0] / 3)
figs, axes = plt.subplots(num_images, 3, figsize=figsize)

for ax in axes.flatten():
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')
def saliency_map(input, output, pred, X):
score = -binary_crossentropy(output[:, pred], np.array([1])).sum()
return np.abs(T.grad(score, input).eval({input: X}))

for n in range(num_images):
heat_img = occlusion_heatmap(
net, X[n:n + 1, :, :, :], target[n], square_length
)

ax = axes if num_images == 1 else axes[n]
img = X[n, :, :, :].mean(0)
ax[0].imshow(-img, interpolation='nearest', cmap='gray')
ax[0].set_title('image')
ax[1].imshow(-heat_img, interpolation='nearest', cmap='Reds')
ax[1].set_title('critical parts')
ax[2].imshow(-img, interpolation='nearest', cmap='gray')
ax[2].imshow(-heat_img, interpolation='nearest', cmap='Reds',
alpha=0.6)
ax[2].set_title('super-imposed')
return plt
def saliency_map_net(net, X):
input = net.layers_[0].input_var
output = get_output(net.layers_[-1])
pred = output.eval({input: X}).argmax(axis=1)
return saliency_map(input, output, pred, X)[0].transpose(1, 2, 0).squeeze()


def plot_saliency(net, X, figsize=(9, None)):
return _plot_heat_map(net, X, figsize, lambda net, X, n: -saliency_map_net(net, X))

0 comments on commit f62353a

Please sign in to comment.