"""RNN network with the classic delayed XOR problem.
"""
import matplotlib.pyplot as plt
import numpy as np
import torch

import preconditioned_stochastic_gradient_descent as psgd

device = torch.device('cpu')

# Increasing seq_len or decreasing dim_hidden makes learning harder; with
# dim_hidden = 30, sequence lengths of roughly 80 ~ 90 can be solved reliably
# without the help of momentum.
batch_size, seq_len = 128, 16
dim_in, dim_hidden, dim_out = 2, 30, 1

def generate_train_data():
    x = np.zeros([batch_size, seq_len, dim_in], dtype=np.float32)
    y = np.zeros([batch_size, dim_out], dtype=np.float32)
    for i in range(batch_size):
        # channel 0: random +/-1 sequence; channel 1: marks the two positions to XOR
        x[i, :, 0] = np.random.choice([-1.0, 1.0], seq_len)
        i1 = int(np.floor(np.random.rand() * 0.1 * seq_len))
        i2 = int(np.floor(np.random.rand() * 0.4 * seq_len + 0.1 * seq_len))
        x[i, i1, 1] = 1.0
        x[i, i2, 1] = 1.0
        if x[i, i1, 0] == x[i, i2, 0]:  # XOR of the two marked bits
            y[i] = -1.0  # label 0
        else:
            y[i] = 1.0  # label 1
    # transpose x to shape (seq_len, batch_size, dim_in)
    return [torch.tensor(np.transpose(x, [1, 0, 2])).to(device),
            torch.tensor(y).to(device)]
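
# Optional sanity check: generate_train_data() returns inputs of shape
# (seq_len, batch_size, dim_in) and targets of shape (batch_size, dim_out)
# with values in {-1.0, +1.0}, e.g.
#     xs, ys = generate_train_data()
#     assert xs.shape == (seq_len, batch_size, dim_in)
#     assert ys.shape == (batch_size, dim_out)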

# generate a random orthogonal matrix for recurrent matrix initialization
def get_rand_orth(dim):
    temp = np.random.normal(size=[dim, dim])
    q, _ = np.linalg.qr(temp)
    return torch.tensor(q, dtype=torch.float32).to(device)
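# Note: an orthogonal recurrent matrix preserves vector norms under the linear
# part of the recurrence, which helps gradients propagate across long delays at
# initialization.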

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.W1x = torch.nn.Parameter(0.1 * torch.randn(dim_in, dim_hidden))
        self.W1h = torch.nn.Parameter(get_rand_orth(dim_hidden))
        self.b1 = torch.nn.Parameter(torch.zeros(dim_hidden))
        self.W2 = torch.nn.Parameter(0.1 * torch.randn(dim_hidden, dim_out))
        self.b2 = torch.nn.Parameter(torch.zeros([]))

    def forward(self, xs):
        # vanilla tanh RNN; only the last hidden state feeds the output layer
        h = torch.zeros(batch_size, dim_hidden, device=device)
        for x in torch.unbind(xs):  # step through the sequence one time step at a time
            h = torch.tanh(x @ self.W1x + h @ self.W1h + self.b1)
        return h @ self.W2 + self.b2

model = Model().to(device)
# choose a flavor of PSGD
# opt = psgd.Newton(model.parameters(), preconditioner_init_scale=None, lr_params=0.01, lr_preconditioner=0.01, grad_clip_max_norm=1.0)
opt = psgd.LRA(model.parameters(), preconditioner_init_scale=None, lr_params=0.01, lr_preconditioner=0.01, grad_clip_max_norm=1.0)
# opt = psgd.XMat(model.parameters(), preconditioner_init_scale=None, lr_params=0.01, lr_preconditioner=0.01, grad_clip_max_norm=1.0)
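# The flavors differ mainly in the structure assumed for the preconditioner
# (e.g., a dense matrix for Newton, a low-rank approximation for LRA, an
# X-shaped matrix for XMat); see the psgd package for the exact definitions.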

def train_loss(xy_pair):  # logistic loss
    return -torch.mean(torch.log(torch.sigmoid(xy_pair[1] * model(xy_pair[0]))))
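# Note: with +/-1 labels, -log(sigmoid(y * f(x))) = softplus(-y * f(x)), i.e.,
# the standard binary logistic loss.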
Losses = []
for num_iter in range(100000):
    train_data = generate_train_data()
    # rng_state = torch.get_rng_state()
    # rng_cuda_state = torch.cuda.get_rng_state()

    def closure():
        # If exact_hessian_vector_product=False and rng is used inside the closure,
        # make sure the rng starts from the same state; otherwise it doesn't matter.
        # torch.set_rng_state(rng_state)
        # torch.cuda.set_rng_state(rng_cuda_state)
        return train_loss(train_data)  # return a loss
        # return [train_loss(train_data),]  # or return a list whose first entry is the loss

    loss = opt.step(closure)
    Losses.append(loss.item())
    print('Iteration: {}; loss: {}'.format(num_iter, Losses[-1]))
    if Losses[-1] < 0.1:
        print('Deemed successful; ending training')
        break

plt.plot(Losses)
plt.show()  # display the training loss curve