forked from NeoNeuron/2aRNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_model.py
131 lines (101 loc) · 3.98 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from model import SingleAreaRNN, TwoAreaRNN
from data import gen_data
import os
from datetime import datetime
def train_model(model, device, seed):
    """Train `model` for 100 epochs on freshly generated trials.

    Parameters
    ----------
    model : nn.Module
        The RNN to train (SingleAreaRNN or TwoAreaRNN from this project).
    device : str or torch.device
        Device the input tensors are moved to; `model` must already live there.
    seed : int
        Unused inside this function (the caller seeds torch/numpy before
        constructing the model); kept in the signature for compatibility.

    Returns
    -------
    list[float]
        Mean batch loss for each epoch.
    """
    n_epochs = 100
    n_trials_per_epoch = 100
    batch_size = 20
    num_batches = n_trials_per_epoch // batch_size
    # Task phase durations in ms — presumably [fixation, stimulus, delay,
    # response]; TODO confirm against gen_data. Index 2 is resampled each epoch.
    task_timing = [300, 1000, 900, 500]
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    loss_list = []
    # Ensure training-mode semantics even if the model was previously switched
    # to eval mode (evaluate_model calls model.eval() and never restores it).
    model.train()
    for epoch in range(n_epochs):
        # Randomize the delay period each epoch so the model cannot simply
        # memorize a fixed response time.
        task_timing_ = task_timing.copy()
        task_timing_[2] = np.random.randint(300, 1500)
        x, y, metadata = gen_data(n_trials_per_epoch, timing=task_timing_)
        x = torch.from_numpy(x).to(device)
        y = torch.from_numpy(y).to(device)
        loss_buff = 0
        for batch in range(num_batches):
            batch_slice = slice(batch * batch_size, (batch + 1) * batch_size)
            output = model(x[batch_slice])
            loss = criterion(output, y[batch_slice])
            loss_buff += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss_list.append(loss_buff / num_batches)
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss {loss_list[-1]:.6f}')
    return loss_list
def evaluate_model(model, device):
    """Evaluate `model` on 100 fresh trials with the default task timing.

    Returns
    -------
    (accuracy, hidden_states) : tuple
        `accuracy` is the fraction of trials where the sign of the final
        output (channel 1) matches the trial's action in metadata.
        `hidden_states` is a numpy array (SingleAreaRNN) or a tuple of numpy
        arrays, one per area (TwoAreaRNN).
    """
    timing = [300, 1000, 900, 500]
    model.eval()
    with torch.no_grad():
        inputs, _targets, meta = gen_data(100, timing=timing)
        inputs = torch.from_numpy(inputs).to(device)
        predictions, hidden = model(inputs, return_hidden=True)
        predictions = predictions.cpu().numpy()
        if isinstance(hidden, tuple):
            # TwoAreaRNN returns one hidden-state tensor per area.
            hidden = tuple(area.cpu().numpy() for area in hidden)
        else:
            # SingleAreaRNN returns a single tensor.
            hidden = hidden.cpu().numpy()
    # Decision = sign of output channel 1 at the final timestep.
    choices = np.sign(predictions[:, -1, 1])
    accuracy = np.mean(choices == meta['action'])
    return accuracy, hidden
def main():
    """Train both architectures over 5 seeds; save loss curves, checkpoints,
    an accuracy boxplot, and a text summary of per-seed accuracies."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs('fig', exist_ok=True)
    os.makedirs('model', exist_ok=True)
    # Name -> (class, constructor kwargs). Both models share the same I/O sizes.
    models_config = {
        '1aRNN': (SingleAreaRNN, {'input_size': 5, 'hidden_size': 100, 'output_size': 2}),
        '2aRNN': (TwoAreaRNN, {'input_size': 5, 'hidden_size': 100, 'output_size': 2})
    }
    all_accuracies = {model_name: [] for model_name in models_config.keys()}
    for seed in range(5):
        print(f"\nTraining models with seed {seed}")
        # Seed both RNGs so model init (torch) and data generation (numpy)
        # are reproducible per seed.
        torch.manual_seed(seed)
        np.random.seed(seed)
        plt.figure(figsize=(10, 6))
        for model_name, (model_class, model_args) in models_config.items():
            print(f"\nTraining {model_name}")
            model = model_class(**model_args).to(device)
            loss_list = train_model(model, device, seed)
            accuracy, hs = evaluate_model(model, device)
            all_accuracies[model_name].append(accuracy)
            plt.semilogy(loss_list, label=model_name)
            save_path = os.path.join('model', f'{model_name}_seed{seed}_acc{accuracy:.3f}.pt')
            torch.save(model.state_dict(), save_path)
        plt.title(f'Loss Curves (Seed {seed})')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f'fig/loss_curves_seed{seed}.png')
        plt.close()
    # Boxplot of accuracies. Derive names from all_accuracies instead of
    # hardcoding '1aRNN'/'2aRNN', so adding a model to models_config needs
    # no change here.
    model_names = list(all_accuracies)
    plt.figure(figsize=(8, 6))
    # NOTE(review): `labels=` was renamed to `tick_labels=` in matplotlib 3.9;
    # kept here for compatibility with older versions — update when the
    # project's minimum matplotlib allows.
    plt.boxplot([all_accuracies[name] for name in model_names],
                labels=model_names)
    plt.title('Model Accuracies')
    plt.ylabel('Accuracy')
    plt.grid(True)
    plt.savefig('fig/accuracy_boxplot.png')
    plt.close()
    # Plain-text summary: per-seed accuracy plus mean/std per model.
    with open('fig/accuracies.txt', 'w') as f:
        for model_name, accs in all_accuracies.items():
            f.write(f'{model_name}:\n')
            for seed, acc in enumerate(accs):
                f.write(f'  Seed {seed}: {acc:.4f}\n')
            f.write(f'  Mean: {np.mean(accs):.4f}\n')
            f.write(f'  Std: {np.std(accs):.4f}\n\n')
# Script entry point: run the full multi-seed training/evaluation pipeline.
if __name__ == '__main__':
    main()