-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path utils.py
62 lines (54 loc) · 1.91 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import pickle
import numpy as np
import mne
import gc
import torch
from dataset import ECGDataset
def load_ecg(path: str):
    """Load an ECG recording, choosing the reader from the file extension.

    Supported inputs:
      - ``.txt``: read with ``np.loadtxt``.
      - ``.csv``: read with ``np.genfromtxt`` using ``','`` as the delimiter.
      - ``.edf``: read with MNE; row 1 of ``get_data()`` is returned.
      - anything else: treated as a WFDB record. The final file extension
        (if any) is stripped to get the record name, then the record and
        its ``'atr'`` annotation are read.

    Args:
        path: path to the recording file (or WFDB record name).

    Returns:
        For ``.txt``/``.csv``/``.edf``, the signal array only. For WFDB
        records, a tuple ``(signal, annotation)`` where ``signal`` is column 0
        of the record's ``p_signal``.
    """
    if path.endswith('.txt'):
        return np.loadtxt(path)
    if path.endswith('.csv'):
        return np.genfromtxt(path, delimiter=',')
    if path.endswith('.edf'):
        # NOTE(review): index 1 picks the second row of get_data(), presumably
        # a specific channel of the EDF file. Confirm this is the intended lead.
        return mne.io.read_raw_edf(path).get_data()[1]

    import wfdb  # imported lazily so the other formats work without wfdb

    # splitext removes only the last extension of the file name. Dots in
    # directory names (e.g. './data/100.dat' or 'db.v1/100') must not be
    # mistaken for the extension, which splitting on the first '.' did.
    record_name = os.path.splitext(path)[0]
    record = wfdb.rdrecord(record_name)
    annotation = wfdb.rdann(record_name, 'atr')
    return record.p_signal[:, 0], annotation
def create_dataset(numbers: list, path: str):
    """Build one ``ECGDataset`` from several ECG records stored under ``path``.

    The first record initialises the dataset and each following record is
    appended to it. A progress line is printed after every record is loaded.

    Args:
        numbers: record identifiers, each joined onto ``path`` to locate the
            file. Non-string identifiers (e.g. the int ``100``) are converted
            with ``str`` first.
        path: directory containing the records.

    Returns:
        The populated ``ECGDataset``, with its ``ecg`` attribute replaced by
        an empty array.

    Raises:
        IndexError: if ``numbers`` is empty.
    """
    # NOTE(review): both the * unpacking below and the two-name unpacking in
    # the loop assume load_ecg returns (signal, annotation), which only its
    # WFDB branch does. Confirm other formats are never passed here.
    first = numbers[0]
    dataset = ECGDataset(*load_ecg(os.path.join(path, str(first))))
    print(f'ECG: {first} done')
    for number in numbers[1:]:
        ecg, annotation = load_ecg(os.path.join(path, str(number)))
        print(f'ECG: {number} done')
        dataset.append(ecg, annotation)
        # Drop local references and collect, so this record's arrays can be
        # reclaimed before the next one is loaded (unless dataset keeps them).
        del ecg, annotation
        gc.collect()
    # Reset dataset.ecg to an empty array, presumably to release the raw
    # signal. TODO confirm nothing downstream reads dataset.ecg.
    dataset.ecg = np.array([])
    return dataset
def prepare_data(i, methods, dataset, labels=None, batch_labels=100):
    """Build one batch of images and one-hot targets from ``dataset``.

    For each class label, up to ``batch_labels`` entries of ``dataset[label]``
    are taken, starting at offset ``i``. They are turned into images with
    ``dataset.extract_images(methods, ...)``, and each image is tagged with
    the one-hot vector for that label's position in ``labels``.

    Args:
        i: start offset into each label's entries.
        methods: passed through unchanged to ``dataset.extract_images``.
        dataset: object indexable by label that also provides
            ``extract_images(methods, entries)``.
        labels: class labels, in one-hot column order. Defaults to
            ``['N', 'A', 'V', 'L', 'R']``.
        batch_labels: maximum number of entries taken per label.

    Returns:
        ``(images, targets)``. ``images`` holds the per-label images
        concatenated along dim 0 in label order. ``targets`` is a float
        tensor with one one-hot row per image and ``len(labels)`` columns.
        A label with no entries in range contributes no rows.
    """
    if labels is None:
        labels = ['N', 'A', 'V', 'L', 'R']
    num_classes = len(labels)
    train_images = torch.tensor([])
    train_lbls = torch.tensor([])
    for class_idx, label in enumerate(labels):
        # Assuming dataset[label] is a list/array-like sequence, slicing clamps
        # at its end, so a short final batch simply yields fewer entries.
        qrs = dataset[label][i:i + batch_labels]
        images = dataset.extract_images(methods, qrs)
        # Use an explicit integer dtype. For an empty batch, torch.tensor([])
        # would be float, and one_hot rejects non-integer input.
        class_ids = torch.full((len(images),), class_idx, dtype=torch.long)
        targets = torch.nn.functional.one_hot(class_ids, num_classes=num_classes).float()
        train_lbls = torch.cat((train_lbls, targets))
        train_images = torch.cat((train_images, images))
        print(f"Label: {label} {len(images)}")
        # Release this label's tensors before processing the next label.
        del images, targets
        gc.collect()
    return train_images, train_lbls