Added visualization code using saliency maps #223
Conversation
Let's try to bring together the common parts between …
I've been using this code for my recent Kaggle stuff, maybe there is something in it that we would want to use here:

```python
import numpy as np
import matplotlib.pyplot as plt
import theano
import theano.tensor as T
from lasagne.layers import get_output


class GradientPlot(object):

    def __init__(self, net):
        # Compile the gradient of the loss w.r.t. the input once, up front.
        loss_function = net.objective_loss_function
        output = get_output(net._output_layer, deterministic=True)
        Xs = net.layers_[0].input_var
        ys = net.y_tensor_type()
        loss = T.mean(loss_function(output, ys))
        grad = T.grad(loss, Xs)
        self.grad_func = theano.function([Xs, ys], grad)

    def plot(self, X, y, exponent=0.3):
        grads = self.grad_func(X, y)
        n = X.shape[0]
        figs, axes = plt.subplots(n, 3, figsize=(12, n * 4))
        [axes[i, j].set_axis_off() for i in range(n) for j in range(3)]
        for i, (xi, grad) in enumerate(zip(X, grads)):
            # original image
            axes[i, 0].imshow(xi.squeeze(), cmap='gray')
            # normalized saliency map, contrast-boosted by `exponent`
            saliency = grad[0]
            saliency /= saliency.max()
            saliency = saliency ** exponent
            axes[i, 1].imshow(-saliency, cmap='gray')
            # saliency super-imposed on the image as a red overlay
            axes[i, 2].imshow(xi.squeeze(), cmap='gray')
            alpha = saliency.reshape(list(saliency.shape) + [-1])
            alpha = np.concatenate(
                (alpha, np.zeros_like(alpha), np.zeros_like(alpha), alpha),
                axis=2)
            axes[i, 2].imshow(alpha, cmap='Reds')
        return plt.gcf()
```

I made it a class so that we only need to compile the Theano function once, but I'm not sure whether it's really needed. Maybe you can give it a spin and see whether you like parts of it better than the suggested implementation.

PS: Why is the negative sometimes plotted? When an image does not consist of 0..255 pixel values, matplotlib tries to infer what should be dark or light (at least that's my impression). Sometimes it gets it wrong and you have to take the negative image. There are probably better solutions, but it worked for me.
@BenjaminBossan @cancan101 Can we try to remove the sign?
Would be cool to update the notebook and add the saliency map example once we've figured out whether the sign should be removed.
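For reference, a common way to drop the sign is to take the absolute value of the input gradient and reduce over the channel axis before normalizing, as in Simonyan et al. (2013). This is only a sketch of that variant, assuming the gradient comes out with shape (channels, height, width) as in GradientPlot above:

```python
import numpy as np


def saliency_from_grad(grad, exponent=0.3):
    # Drop the sign and collapse the channel axis; only the magnitude of the
    # gradient matters for the saliency map.
    saliency = np.abs(grad).max(axis=0)
    # Normalize to [0, 1]; the small epsilon guards against an all-zero gradient.
    saliency = saliency / (saliency.max() + 1e-12)
    # Gamma-style contrast boost, as in GradientPlot.plot.
    return saliency ** exponent
```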
```python
    ax[2].set_title('super-imposed')
    return plt


def saliency_map_net(net, X):
    input = net.layers_.values()[0].input_var
```
Use net.layers_[0]
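Applied to the snippet above, the suggested change would look roughly like this (the GradientPlot code earlier already uses the same net.layers_[0] access; going through .values() also breaks on Python 3, where dict views are not subscriptable):

```python
def saliency_map_net(net, X):
    # Suggested change: index layers_ directly instead of going through
    # .values(), which is not subscriptable on Python 3.
    input = net.layers_[0].input_var
    ...  # rest of the function unchanged
```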
Added visualization code using saliency maps
No description provided.