-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrecurrent_neural_network.py
65 lines (52 loc) · 2.14 KB
/
recurrent_neural_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import matplotlib.pyplot as plt
from scratch_ml.utils import to_categorical, train_test_split, CrossEntropy, accuracy_score
from scratch_ml.deep_learning import NeuralNetwork
from scratch_ml.deep_learning.layers import RNN, Activation
from scratch_ml.deep_learning.optimizers import Adam
def main():
def gen_mult_ser(nums):
"""Method which generates multiplication series."""
X = np.zeros([nums, 10, 61], dtype=float)
y = np.zeros([nums, 10, 61], dtype=float)
for i in range(nums):
start = np.random.randint(2, 7)
mult_ser = np.linspace(start, start*10, num=10, dtype=int)
X[i] = to_categorical(mult_ser, n_col=61)
y[i] = np.roll(X[i], -1, axis=0)
y[:, -1, 1] = 1
return X, y
X, y = gen_mult_ser(3000)
X_train, X_test, y_train, y_test = train_test_split(X, y)
optimizer = Adam()
Model = NeuralNetwork(optimizer=optimizer, loss=CrossEntropy)
Model.add(RNN(10, activation="tanh", bptt=5, input_shape=(10, 61)))
Model.add(Activation('softmax'))
Model.summary("RNN")
tmp_X = np.argmax(X_train[0], axis=1)
tmp_y = np.argmax(y_train[0], axis=1)
print("Number Series Problem:")
print("X = [" + " ".join(tmp_X.astype("str")) + "]")
print("y = [" + " ".join(tmp_y.astype("str")) + "]")
print()
train_err, _ = Model.fit(X_train, y_train, n_epochs=500, batch_size=512)
y_pred = np.argmax(Model.predict(X_test), axis=2)
y_test = np.argmax(y_test, axis=2)
print("Results:")
for i in range(5):
tmp_X = np.argmax(X_test[i], axis=1)
tmp_y1 = y_test[i]
tmp_y2 = y_pred[i]
print("X = [" + " ".join(tmp_X.astype("str")) + "]")
print("y_true = [" + " ".join(tmp_y1.astype("str")) + "]")
print("y_pred = [" + " ".join(tmp_y2.astype("str")) + "]")
print()
accuracy = np.mean(accuracy_score(y_test, y_pred))
print("Accuracy:", accuracy)
plt.plot(range(500), train_err, label="Training Error")
plt.title("Error Plot")
plt.ylabel("Training Error")
plt.xlabel("Iterations")
plt.show()
if __name__ == "__main__":
main()