-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathdataset.py
97 lines (82 loc) · 2.55 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# -*- coding: utf-8 -*-
'''
@time: 2019/9/8 19:47
@ author: javis
'''
import pywt, os, copy
import torch
import numpy as np
import pandas as pd
from config import config
from torch.utils.data import Dataset
from sklearn.preprocessing import scale
from scipy import signal
def resample(sig, target_point_num=None):
    """Resample the raw signal to a fixed number of points.

    :param sig: original signal, handed unchanged to ``scipy.signal.resample``
        (which resamples along axis 0 by default)
    :param target_point_num: desired number of points; when falsy (``None``
        or ``0``) no resampling is done
    :return: the resampled signal, or ``sig`` itself when no target is given
    """
    if not target_point_num:
        return sig
    return signal.resample(sig, target_point_num)
def scaling(X, sigma=0.1):
    """Multiply every column of ``X`` by its own random gain.

    One gain per column is drawn from a normal distribution centred on 1.0
    with standard deviation ``sigma``; the same gain is used for all rows of
    that column.

    :param X: 2-D array whose columns are scaled independently
    :param sigma: standard deviation of the random gains
    :return: a new array ``X * gain``; ``X`` itself is not modified
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    per_column = np.random.normal(loc=1.0, scale=sigma, size=(1, n_cols))
    # Replicate the single row of gains down all rows of X.
    column_of_ones = np.ones((n_rows, 1))
    gain = np.matmul(column_of_ones, per_column)
    return X * gain
def verflip(sig):
    """Flip the signal vertically by reversing the order of its rows.

    :param sig: 2-D signal array
    :return: ``sig`` with axis 0 reversed and axis 1 untouched (for NumPy
        arrays this is a view on the input, not a copy)
    """
    rows_reversed = slice(None, None, -1)
    return sig[rows_reversed, :]
def shift(sig, interval=20):
    """Shift every column up or down by its own random integer offset.

    For each column an offset is drawn uniformly from
    ``range(-interval, interval)`` (``-interval`` included, ``interval``
    excluded) and added to all values of that column.

    The input is left untouched: the shift is applied to a copy, so callers
    that keep a reference to ``sig`` do not see their data change.

    :param sig: 2-D signal array
    :param interval: bound of the random offset; ``0`` disables the shift
    :return: a shifted copy of ``sig``
    :raises ValueError: if ``interval`` is negative
    """
    if interval < 0:
        raise ValueError("interval must be non-negative, got %r" % (interval,))

    shifted = sig.copy()
    if interval == 0:
        # range(0, 0) is empty and np.random.choice would raise on it;
        # a zero interval simply means "no shift".
        return shifted

    # Loop-invariant; one draw per column keeps seeded runs reproducible.
    candidates = range(-interval, interval)
    for col in range(shifted.shape[1]):
        shifted[:, col] += np.random.choice(candidates)
    return shifted
def transform(sig, train=False):
    """Turn a raw signal array into a float tensor ready for the model.

    Steps:
      1. resample to ``config.target_point_num`` points (always);
      2. when ``train`` is true, apply ``scaling``, ``verflip`` and ``shift``
         in that order, each one independently with probability 0.5;
      3. transpose the array and convert it to a ``torch.float`` tensor
         (always).

    :param sig: raw signal array (presumably ``(points, channels)`` — confirm
        against the caller)
    :param train: enable the random augmentations
    :return: ``torch.Tensor`` of dtype ``torch.float`` holding the transposed
        signal
    """
    # Mandatory pre-processing step.
    sig = resample(sig, config.target_point_num)

    # Data augmentation.
    if train:
        for augment in (scaling, verflip, shift):
            # np.random.rand() is uniform on [0, 1), so "> 0.5" is a fair
            # coin. The previous np.random.randn() is a standard normal and
            # exceeded 0.5 only about 31% of the time.
            if np.random.rand() > 0.5:
                sig = augment(sig)

    # Mandatory post-processing step.
    sig = sig.transpose()
    # copy() materialises a contiguous array first; verflip returns a
    # reversed view, which torch cannot convert directly.
    sig = torch.tensor(sig.copy(), dtype=torch.float)
    return sig
class ECGDataset(Dataset):
    """Dataset yielding one signal tensor and its multi-hot label per index.

    The split description is a dict stored with ``torch.save`` and arranged
    in this way::

        dd = {'train': train, 'val': val, "idx2name": idx2name, 'file2idx': file2idx}

    The constructor additionally reads a ``'wc'`` entry from that dict.
    """

    def __init__(self, data_path, train=True):
        """Load the split description and select the requested split.

        :param data_path: path of the file to ``torch.load``; when empty or
            ``None`` the default ``config.train_data`` is used
        :param train: pick ``dd['train']`` (and enable augmentation in
            ``__getitem__``) when true, ``dd['val']`` otherwise
        """
        super(ECGDataset, self).__init__()
        # Honour the caller's path instead of silently ignoring it.
        # NOTE: torch.load unpickles the file — only load trusted data.
        dd = torch.load(data_path if data_path else config.train_data)
        self.train = train
        self.data = dd['train'] if train else dd['val']
        self.idx2name = dd['idx2name']
        self.file2idx = dd['file2idx']
        # Inverse-log weights derived from dd['wc'] (presumably per-class
        # sample counts — confirm against the code that builds the dict).
        self.wc = 1. / np.log(dd['wc'])

    def __getitem__(self, index):
        """Read one record from disk and return ``(signal, target)``.

        :param index: position inside the selected split
        :return: tuple of the transformed signal tensor and a float32 target
            vector of length ``config.num_classes`` with ones at the indices
            listed in ``file2idx`` for this record
        """
        fid = self.data[index]
        file_path = os.path.join(config.train_dir, fid)
        # Records are space-separated text files.
        record = pd.read_csv(file_path, sep=' ').values
        x = transform(record, self.train)

        target = np.zeros(config.num_classes)
        target[self.file2idx[fid]] = 1
        target = torch.tensor(target, dtype=torch.float32)
        return x, target

    def __len__(self):
        """Return the number of records in the selected split."""
        return len(self.data)
if __name__ == '__main__':
    # Smoke test: build the training split and print its first sample.
    dataset = ECGDataset(config.train_data)
    first_sample = dataset[0]
    print(first_sample)