-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqrs.py
90 lines (79 loc) · 3.46 KB
/
qrs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gc
import pywt
import torch
import numpy as np
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
class QRS:
    """One ECG beat segment plus its class-confidence labels and image encodings.

    Labels are stored as a dict mapping each entry of ``LABELS`` to a float
    confidence; ``label`` / ``confidence`` hold the arg-max entry.
    ``get_images`` renders the segment through matplotlib into
    ``IMG_SIZE``-shaped float tensors scaled to [0, 1].
    """

    LABELS = ['N', 'A', 'V', 'L', 'R']
    METHODS = ['ecg', 'cgau4', 'cgau4_gray', 'cgau4_fft2', 'dwt']
    IMG_SIZE = (255, 255, 3)  # (height, width, channels) of rendered images

    def __init__(self, ecg: object, label: object = None, fs: int = 360, peaks=None, start=0):
        """Store the segment and derive its label dict.

        Args:
            ecg: the signal samples; must be accepted by ``np.fft.fft``
                (presumably a 1-D array — confirm against callers).
            label: ``None``, a label string, or a sequence/array/tensor of
                per-label confidences (see ``set_label``).
            fs: stored as ``self.fs``; not otherwise used by this class
                (presumably the sampling rate in Hz — verify).
            peaks: optional 5-element sequence unpacked into
                ``self.p, self.q, self.r, self.s, self.t``.
            start: stored as ``self.start``.
        """
        self.labels = None
        self.label = None
        self.confidence = None
        self.ecg = ecg
        self.start = start
        if peaks is not None:
            self.p, self.q, self.r, self.s, self.t = peaks
        self.set_label(label)
        self.fs = fs
        # NOTE(review): two successive 1-D FFTs, not np.fft.fft2; this is the
        # signal fed to the 'cgau4_fft2' encoding.
        self.fft2 = np.fft.fft(np.fft.fft(ecg))
        self.images = {}  # not populated by this class; use get_images()

    def __str__(self):
        """Return the most confident label and its confidence, e.g. ``'V 100%'``."""
        label = max(self.labels, key=self.labels.get)
        return f'{label} {self.labels[label] * 100:.0f}%'

    def set_label(self, label: object):
        """Set ``labels``, ``label`` and ``confidence`` from *label*.

        Accepted forms:
          * ``None`` – every confidence is 0.0 (``label`` becomes ``'N'``).
          * ``str`` – case-insensitive one-hot over ``LABELS``.
          * list / tuple / ``np.ndarray`` / ``torch.Tensor`` – confidences
            paired positionally with ``LABELS`` (``zip`` truncates to the
            shorter of the two).

        Raises:
            ValueError: for any other type.
        """
        if label is None:
            self.labels = {l: 0. for l in self.LABELS}
        elif isinstance(label, str):
            self.labels = {l: 1. if label.upper() == l else 0. for l in self.LABELS}
        elif isinstance(label, (list, tuple, np.ndarray, torch.Tensor)):
            self.labels = {l: float(conf) for conf, l in zip(label, self.LABELS)}
        else:
            raise ValueError(f'Unknown label type {type(label)}')
        # Pick the label with the highest confidence.
        self.label = max(self.labels, key=self.labels.get)
        self.confidence = self.labels[self.label]

    def get_images(self, methods=None):
        """Render the segment with each requested encoding.

        Args:
            methods: iterable of names from ``METHODS``; defaults to every
                method except ``'ecg'``.

        Returns:
            dict mapping method name to a float tensor of shape
            ``IMG_SIZE`` with values in [0, 1].

        Raises:
            ValueError: if any requested method is unknown (checked before
                any rendering work is done).
        """
        if methods is None:
            methods = self.METHODS[1:]
        for method in methods:
            if method not in self.METHODS:
                raise ValueError(f'Unknown method {method}')

        ecg_cwt = None  # shared by 'cgau4' and 'cgau4_gray'; computed lazily
        images = {}
        for method in methods:
            if method == 'ecg':
                images['ecg'] = self._ax_to_img(plt.plot, self.ecg)
            elif method == 'dwt':
                coeffs = np.abs(pywt.dwt(self.ecg, 'db38')[0].reshape(1, -1))
                images['dwt'] = self._ax_to_img(plt.imshow, coeffs, cmap='jet', aspect='auto')
            elif method in ('cgau4', 'cgau4_gray'):
                if ecg_cwt is None:
                    ecg_cwt = self._cwt_row(self.ecg, len(self.ecg))
                cmap = 'gray' if method == 'cgau4_gray' else 'jet'
                images[method] = self._ax_to_img(plt.imshow, ecg_cwt, cmap=cmap, aspect='auto')
            elif method == 'cgau4_fft2':
                coeffs = self._cwt_row(self.fft2, len(self.ecg))
                images['cgau4_fft2'] = self._ax_to_img(plt.imshow, coeffs, cmap='jet', aspect='auto')
        return images

    @staticmethod
    def _cwt_row(signal, n_samples):
        """Return |cgau4 CWT of *signal}| over scales 1..n_samples//2 - 1, as one row."""
        scales = np.arange(1, n_samples // 2)
        if scales.size == 0:
            raise ValueError(f'ECG segment of length {n_samples} is too short for a CWT')
        coefs = pywt.cwt(signal, scales, 'cgau4')[0]
        # NOTE(review): reshape(1, -1) flattens the (scales, samples) scalogram
        # into a single row, so imshow draws a 1-row strip rather than a 2-D
        # scalogram. Kept as-is because downstream models were presumably
        # trained on these images — confirm before changing.
        return np.abs(coefs.reshape(1, -1))

    @staticmethod
    def _normalize(image):
        """Min-max scale *image* to [0, 1] as float64; a constant image maps to zeros."""
        image = np.asarray(image, dtype=np.float64)
        lo = image.min()
        span = image.max() - lo
        if span == 0:
            # Avoid 0/0 -> NaN for uniform renders.
            return np.zeros_like(image)
        return (image - lo) / span

    @staticmethod
    def _ax_to_img(plot_func, *args, **kwargs):
        """Draw ``plot_func(*args, **kwargs)`` on a fresh figure and rasterise it.

        Returns:
            float32 tensor of shape ``(IMG_SIZE[0], IMG_SIZE[1], 3)`` scaled
            to [0, 1] (zeros if the render is a single uniform colour).
        """
        fig = plt.figure(figsize=(8, 6))
        try:
            plot_func(*args, **kwargs)
            plt.axis('off')
            plt.tight_layout(pad=0)
            with BytesIO() as buffer:
                fig.savefig(buffer, bbox_inches='tight', pad_inches=0, format='png')
                buffer.seek(0)
                height, width = QRS.IMG_SIZE[:2]
                # PIL's resize takes (width, height).
                with Image.open(buffer) as png:
                    pixels = np.array(png.resize((width, height)))[..., :3]
        finally:
            # Always release the figure, even if plotting/saving failed.
            plt.close(fig)
        image = QRS._normalize(pixels)
        gc.collect()  # matplotlib/PIL buffers can linger between many renders
        return torch.tensor(image).float()

    def plot(self):
        """Show the raw signal in an interactive matplotlib window."""
        plt.plot(self.ecg)
        plt.show()