-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnnplot.py
121 lines (106 loc) · 4.16 KB
/
nnplot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import mlab
def plotPost2D(mdn, y,
               rangex=(0, 1), rangey=(0, 1),
               deltax=0.01, deltay=0.01,
               true_model=None):
    """Plot the 2-D mixture density produced by ``mdn`` for network output ``y``.

    The density sum_k alpha[k] * N(x; mu[k, :2], sigma[k] * I) is evaluated on
    a regular grid covering ``rangex`` x ``rangey`` and drawn with
    ``plt.imshow``.  Each component is an isotropic Gaussian whose variance is
    ``sigma[k]`` on both axes.  The mixture parameters and ``true_model`` are
    printed to stdout.

    Parameters
    ----------
    mdn : object
        Must expose ``M`` (number of mixture components) and
        ``getMixtureParams(y)`` returning ``(alpha, sigma, mu)``.
    y : array-like
        Network output passed to ``mdn.getMixtureParams``.
    rangex, rangey : sequence of two floats
        Half-open ``[start, stop)`` extent of the grid along x and y.
    deltax, deltay : float
        Grid spacing along x and y.
    true_model : sequence of two floats, optional
        If given, red cross-hair lines are drawn at
        ``(true_model[0], true_model[1])``.
    """
    M = mdn.M
    alpha, sigma, mu = mdn.getMixtureParams(y)
    print('mu: ' + str(mu))
    print('sigma: ' + str(sigma))
    print('true value: ' + str(true_model))
    xlin = np.arange(rangex[0], rangex[1], deltax)
    ylin = np.arange(rangey[0], rangey[1], deltay)
    XLIN, YLIN = np.meshgrid(xlin, ylin)
    P = np.zeros([ylin.shape[0], xlin.shape[0]])
    for k in range(M):
        # Isotropic bivariate normal.  This is exactly what the former
        # matplotlib.mlab.bivariate_normal(X, Y, s, s, mux, muy) returned for
        # s = sqrt(sigma[k]); that helper was removed in matplotlib 3.1.
        sq_dist = (XLIN - mu[k, 0]) ** 2 + (YLIN - mu[k, 1]) ** 2
        density = np.exp(-sq_dist / (2.0 * sigma[k])) / (2.0 * np.pi * sigma[k])
        P = P + density * alpha[k]
    plt.imshow(P,
               origin='lower',
               extent=[rangex[0], rangex[1],
                       rangey[0], rangey[1]])
    # `is not None` instead of `== None`: comparing a numpy array to None is
    # elementwise, and `not <array>` then raises ValueError.
    if true_model is not None:
        plt.axvline(true_model[0], c='r')
        plt.axhline(true_model[1], c='r')
def plotPost1D(mdn, y, rangex=(0, 1), deltax=0.01, true_model=None):
    """Plot the 1-D mixture density produced by ``mdn`` for network output ``y``.

    Evaluates sum_k alpha[k] * N(x; mu[k, 0], sigma[k]) on
    ``np.arange(rangex[0], rangex[1], deltax)`` and draws it with
    ``plt.plot``.  ``sigma[k]`` is used as the variance of component ``k``.

    Parameters
    ----------
    mdn : object
        Must expose ``M`` (number of mixture components) and
        ``getMixtureParams(y)`` returning ``(alpha, sigma, mu)``.
    y : array-like
        Network output passed to ``mdn.getMixtureParams``.
    rangex : sequence of two floats
        Half-open ``[start, stop)`` range of evaluation points.
    deltax : float
        Spacing of the evaluation points.
    true_model : float, optional
        If given, a red vertical line is drawn at this x position.
    """
    alpha, sigma, mu = mdn.getMixtureParams(y)
    xlin = np.arange(rangex[0], rangex[1], deltax)
    P = np.zeros([xlin.shape[0]])
    for k in range(mdn.M):
        density = ((1.0 / (2 * np.pi * sigma[k]) ** 0.5)
                   * np.exp(-1.0 * (xlin - mu[k, 0]) ** 2 / (2 * sigma[k])))
        P = P + density * alpha[k]
    plt.plot(xlin, P)
    # Identity check: `!= None` is elementwise for numpy arrays.
    if true_model is not None:
        plt.axvline(true_model, c='r')
def plotPostCond(mdn, x, t):
    """Render the network's conditional mixture density over ``t`` as an image.

    ``mdn.forward(x)`` is converted into mixture parameters (alpha, sigma, mu).
    Image row ``i`` holds, for every target value ``t_j``,
    sum_k alpha[k, i] * N(t_j; mu[k, 0, i], sigma[k, i]), with ``sigma``
    acting as a variance.  Both image axes span ``[min(t), max(t)]``.

    NOTE(review): the image is ``len(t)`` x ``len(t)``, so this assumes ``t``
    is 1-D and that the mixture parameters have exactly ``t.shape[0]``
    columns (presumably one per sample of ``x``) — TODO confirm with callers.
    """
    net_out = mdn.forward(x)
    alpha, sigma, mu = mdn.getMixtureParams(net_out)
    size = t.shape[0]
    # Per-component densities are kept in a float64 buffer before weighting.
    densities = np.zeros([mdn.M, size, size])
    image = np.zeros([size, size])
    # Targets vary along columns; per-row parameters vary along rows.
    t_row = t[np.newaxis, :]
    for comp in range(mdn.M):
        var_col = sigma[comp, :][:, np.newaxis]
        mean_col = mu[comp, 0, :][:, np.newaxis]
        weight_col = alpha[comp, :][:, np.newaxis]
        scale = 1.0 / (2 * np.pi * var_col) ** 0.5
        densities[comp, :, :] = scale * np.exp(-1.0 * (t_row - mean_col) ** 2 / (2 * var_col))
        image = image + densities[comp, :, :] * weight_col
    lo, hi = min(t), max(t)
    plt.imshow(image, origin='lower', extent=[lo, hi, lo, hi])
def plotPostMap(mdn, x, t):
    """Unfinished stub: compute mixture parameters for ``x`` but draw nothing.

    The output of ``mdn.forward(x)`` is passed to ``mdn.getMixtureParams`` and
    the returned (alpha, sigma, mu) triple is discarded.  ``t`` is unused.
    Always returns None.
    """
    # TODO: no plotting implemented yet; presumably meant to visualise
    # something derived from these parameters (confirm intent with author).
    _alpha, _sigma, _mu = mdn.getMixtureParams(mdn.forward(x))
def plotModelVsTrue(mdn, y, t, thres=0.7, dim=0):
    """Scatter the mean of each sample's most probable component against ``t``.

    For every sample ``n``, the component with the largest mixing coefficient
    ``alpha[:, n]`` is selected.  Coordinate ``dim`` of its mean
    ``mu[k, :, n]`` is plotted on the x axis against the true value ``t[n]``.
    A warning is printed to stdout if any selected mixing coefficient is
    ``<= thres``.

    Parameters
    ----------
    mdn : object
        Must expose ``c`` (dimension of each component mean) and
        ``getMixtureParams(y)`` returning ``(alpha, sigma, mu)``, indexed here
        as ``alpha[k, n]`` and ``mu[k, :, n]``.
    y : array-like
        Network output; ``y.shape[0]`` is used as the number of samples
        (NOTE(review): assumes this matches ``alpha.shape[1]`` — confirm).
    t : sequence of float
        True target values, one per sample; also sets both axis limits.
    thres : float
        Minimum mixing coefficient expected for the chosen components.
    dim : int
        Which coordinate of the component mean to plot.
    """
    alpha, _, mu = mdn.getMixtureParams(y)
    n_samples = y.shape[0]
    # Most probable component per sample.
    idx_max = np.argmax(alpha, axis=0)
    mu_max = np.zeros([n_samples, mdn.c])
    alpha_max = np.zeros([n_samples])
    for n in range(n_samples):
        mu_max[n, :] = mu[idx_max[n], :, n]
        # Store per sample (do not rebind the array to a scalar) so the
        # threshold check below covers every sample, not just the last one.
        alpha_max[n] = alpha[idx_max[n], n]
    if np.any(alpha_max <= thres):
        print('Warning, not all mixture coefficients are above threshold.')
    plt.scatter(mu_max[:, dim], t)
    plt.xlim([min(t), max(t)])
    plt.ylim([min(t), max(t)])
    plt.xlabel('prediction')
    plt.ylabel('true value')
def plotTextBox(s):
    """Draw an unfilled box in the centre of the current axes and write ``s`` in it.

    The box spans 25%-75% of the axes in both directions (axes coordinates).
    The text is centred horizontally within the box and hangs down from the
    box's top edge.
    """
    left, width = .25, .5
    bottom, height = .25, .5
    right = left + width
    top = bottom + height
    ax = plt.gca()
    box = plt.Rectangle((left, bottom), width, height, fill=False)
    # Position the rectangle in axes coordinates (0..1) rather than data units.
    box.set_transform(ax.transAxes)
    box.set_clip_on(False)
    ax.add_patch(box)
    # Anchor at the box's horizontal centre.  With horizontalalignment='center',
    # anchoring at `left` would centre the text on the box's left border.
    ax.text(0.5 * (left + right), top, s,
            horizontalalignment='center',
            verticalalignment='top',
            transform=ax.transAxes)