#!/usr/bin/env python
# coding: utf-8
# ## Autoencoders
# #### By Ali Shannon
#
# This short script shows how to build an autoencoder in PyTorch. The goal is to reduce the number of dimensions (i.e. compress the feature space) with a neural network.
#
# The idea is simple: let the network learn the encoder and the decoder by using the same features as both its input and its target output.
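# As a minimal sketch of the objective (illustrative pseudocode, not part of
# the original script): with an encoder f and a decoder g, training minimizes
# the reconstruction error
#
#     loss = MSE(g(f(x)), x)
#
# so the low-dimensional code f(x) must retain enough information to rebuild
# the original input.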
# In[8]:
import torch
from torch import nn, optim
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib
from sklearn.datasets import make_swiss_roll
from sklearn.preprocessing import MinMaxScaler
# Here I am using the swiss roll example and reducing it from 3D to 2D.
# In[10]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
n_samples = 1500
noise = 0.05
X, colors = make_swiss_roll(n_samples, noise=noise)  # noise is keyword-only in recent scikit-learn
X = MinMaxScaler().fit_transform(X)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')  # Axes3D(fig) no longer auto-attaches in recent matplotlib
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=colors, cmap=plt.cm.jet)
plt.title('Swiss roll')
plt.show()
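# Note: make_swiss_roll draws random samples, so the figures below vary from
# run to run; passing random_state=0 (or any fixed seed) would make them
# reproducible. That argument is an optional addition, not in the original.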
# In[38]:
x = torch.from_numpy(X).to(device)
class Autoencoder(nn.Module):
    """Simple denoising autoencoder.

    Parameters
    ----------
    in_shape [int] : input shape
    enc_shape [int] : desired encoded shape
    """

    def __init__(self, in_shape, enc_shape):
        super().__init__()
        self.encode = nn.Sequential(
            nn.Linear(in_shape, 128),
            nn.ReLU(True),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Dropout(0.2),
            nn.Linear(64, enc_shape),
        )
        self.decode = nn.Sequential(
            nn.BatchNorm1d(enc_shape),  # normalizes the 2-D codes before decoding
            nn.Linear(enc_shape, 64),
            nn.ReLU(True),
            nn.Dropout(0.2),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Dropout(0.2),
            nn.Linear(128, in_shape),
        )

    def forward(self, x):
        x = self.encode(x)
        x = self.decode(x)
        return x
encoder = Autoencoder(in_shape=3, enc_shape=2).double().to(device)  # .double() matches the float64 numpy input
error = nn.MSELoss()
optimizer = optim.Adam(encoder.parameters())
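# Optional sanity check (not in the original): one forward pass should return
# a tensor with the same shape as the input, since the decoder maps the 2-D
# code back to 3 dimensions.
with torch.no_grad():
    assert encoder(x).shape == x.shape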
# In[39]:
def train(model, error, optimizer, n_epochs, x):
    model.train()
    log_every = max(1, n_epochs // 10)  # guard against n_epochs < 10
    for epoch in range(1, n_epochs + 1):
        optimizer.zero_grad()
        output = model(x)
        loss = error(output, x)
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            print(f'epoch {epoch} \t Loss: {loss.item():.4g}')
# You can rerun this function or simply increase the number of epochs. Dropout was added for denoising; without it, the network would be very sensitive to small input variations.
# In[42]:
train(encoder, error, optimizer, 5000, x)
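# A rough check of the denoising behaviour mentioned above (illustrative; the
# 0.01 noise scale is an arbitrary choice, not part of the original script):
encoder.eval()  # disable dropout and use batch-norm running statistics
with torch.no_grad():
    x_noisy = x + 0.01 * torch.randn_like(x)
    noisy_loss = error(encoder(x_noisy), x).item()
print(f'Loss reconstructing from noisy input: {noisy_loss:.4g}')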
# In[43]:
encoder.eval()  # important: turn off dropout for inference
with torch.no_grad():
    encoded = encoder.encode(x)
    decoded = encoder.decode(encoded)
    mse = error(decoded, x).item()
    enc = encoded.cpu().numpy()
    dec = decoded.cpu().numpy()
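# Extra diagnostic (not in the original): the distribution of per-point
# reconstruction errors is often more informative than a single aggregate MSE.
per_point_err = ((dec - X) ** 2).mean(axis=1)
print(f'median per-point squared error: {np.median(per_point_err):.4g}')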
# In[44]:
plt.scatter(enc[:, 0], enc[:, 1], c=colors, cmap=plt.cm.jet)
plt.title('Encoded Swiss Roll')
plt.show()
# In[45]:
fig = plt.figure(figsize=(15,6))
ax = fig.add_subplot(121, projection='3d')
ax.scatter(X[:,0], X[:,1], X[:,2], c=colors, cmap=plt.cm.jet)
plt.title('Original Swiss roll')
ax = fig.add_subplot(122, projection='3d')
ax.scatter(dec[:,0], dec[:,1], dec[:,2], c=colors, cmap=plt.cm.jet)
plt.title('Decoded Swiss roll')
plt.show()
print(f'Root mean squared error: {np.sqrt(mse):.4g}')
# Some variance is inevitably lost in the dimensionality reduction, but the reconstruction is quite faithful. Here is how the model responds to a new, noisier roll it was never trained on.
# In[118]:
n_samples = 2500
noise = 0.1
X, colors = make_swiss_roll(n_samples, noise=noise)
X = MinMaxScaler().fit_transform(X)
x = torch.from_numpy(X).to(device)
with torch.no_grad():  # encoder is still in eval mode from above
    encoded = encoder.encode(x)
    decoded = encoder.decode(encoded)
    mse = error(decoded, x).item()
    enc = encoded.cpu().numpy()
    dec = decoded.cpu().numpy()
fig = plt.figure(figsize=(15,6))
ax = fig.add_subplot(121, projection='3d')
ax.scatter(X[:,0], X[:,1], X[:,2], c=colors, cmap=plt.cm.jet)
plt.title('New Swiss roll')
ax = fig.add_subplot(122, projection='3d')
ax.scatter(dec[:,0], dec[:,1], dec[:,2], c=colors, cmap=plt.cm.jet)
plt.title('Decoded Swiss roll')
plt.show()
print(f'Root mean squared error: {np.sqrt(mse):.4g}')