-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
66 lines (54 loc) · 2.23 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import os
import gc
import random
from transformer import ECGDETR
from dataset import ECGDataset
from qrs import QRS
from utils import prepare_data, create_dataset, load_ecg
def _preview_record(record='mitbit/223', n_samples=12000):
    """Run a multi-method ECGDETR on the start of one record and plot a window.

    This is the quick visual sanity check the script has always performed
    before training. The record is read with ``load_ecg`` and only the first
    ``n_samples`` values are passed to ``predict``.
    """
    ecg, _ = load_ecg(record)
    preview = ECGDETR(QRS.METHODS[1:])
    # predict() is called for its side effect of populating preview.dataset,
    # which plot_ecg below reads -- presumably; confirm against ECGDETR.
    preview.predict(ecg[:n_samples])
    preview.dataset.plot_ecg(30, 32)


def _load_dataset(cache_path, numbers, path):
    """Return the pickled dataset at ``cache_path``, building and caching it if absent.

    Args:
        cache_path: Pickle file holding a previously built dataset.
        numbers: Record identifiers passed to ``create_dataset`` on a cache miss.
        path: Directory of the raw records passed to ``create_dataset``.
    """
    import pickle

    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # caches this script produced itself, never files from elsewhere.
            return pickle.load(f)
    dataset = create_dataset(numbers, path)
    with open(cache_path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset


def _balance_dataset(dataset, max_items):
    """Keep at most ``max_items`` samples of each label, then shuffle them together.

    ``split_ecg`` is shuffled first so the ``[:max_items]`` slice per label is
    a random subset rather than the first items -- assuming ``dataset[label]``
    draws from ``split_ecg`` in order (TODO confirm against ECGDataset).
    The result replaces ``dataset.split_ecg`` in place.
    """
    random.shuffle(dataset.split_ecg)
    balanced = []
    for label in QRS.LABELS:
        balanced.extend(dataset[label][:max_items])
    random.shuffle(balanced)
    dataset.split_ecg = balanced


def _batch_plan(max_items, batch_labels, train_fraction=0.9):
    """Split offsets ``[0, max_items)`` into consecutive batches.

    Returns a list of ``(start, size, is_train)`` tuples:

    - A batch trains when its start offset is below ``max_items * train_fraction``.
      Otherwise it is used for evaluation.
    - ``size`` is clipped to ``max_items - start``. The final batch therefore
      never overruns, and no evaluation batch is empty when ``batch_labels``
      divides ``max_items`` evenly.

    >>> _batch_plan(2200, 2000)
    [(0, 2000, True), (2000, 200, False)]
    """
    return [
        (start, min(batch_labels, max_items - start), start < max_items * train_fraction)
        for start in range(0, max_items, batch_labels)
    ]


def _train_and_evaluate(method, dataset, max_items, batch_labels, epochs):
    """Train a single-method ECGDETR batch by batch, evaluate held-out batches, plot history.

    Batches come from ``_batch_plan``. Each batch's data is released right after
    use to keep peak memory down.
    """
    # Directory is prepared for model.save (currently disabled below).
    os.makedirs(f'models/{method}', exist_ok=True)
    model = ECGDETR([method])
    for start, size, is_train in _batch_plan(max_items, batch_labels):
        # NOTE(review): assumes prepare_data treats `start`/`batch_labels` as
        # offsets into the balanced dataset -- confirm against utils.prepare_data.
        images, labels = prepare_data(start, [method], dataset, batch_labels=size)
        images = images.reshape(-1, 1, *QRS.IMG_SIZE)
        labels = labels.reshape(-1, len(QRS.LABELS))
        if is_train:
            model.train_model(images, labels,
                              epochs=epochs, early_stopping_patience=epochs // 10)
            print(f"Batch: {start} done")
        else:
            model.evaluate(images, labels)
        del images, labels
        gc.collect()
    # model.save(f'models')
    model.plot_training()


if __name__ == '__main__':
    _preview_record()

    path = 'mitbit'
    numbers = ['200', '208', '233', '100', '101', '103', '113', '115', '109', '111', '214', '118', '124', '212', '231',
               '209', '220', '222', '223', '232', '106', '119']  # 207 left out (reason not recorded here)
    dataset = _load_dataset('dataset_new.pkl', numbers, path)

    batch_labels = 2000
    max_items = 2200
    methods = QRS.METHODS[1:]  # ['dwt', 'stft', 'spectrogram', 'cwt', 'fft', 'welch']
    epochs = 1000

    _balance_dataset(dataset, max_items)
    for method in methods:
        _train_and_evaluate(method, dataset, max_items, batch_labels, epochs)