-
Notifications
You must be signed in to change notification settings - Fork 4
/
datasets.py
132 lines (104 loc) · 4.3 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import logging
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
logger = logging.getLogger(__name__)
# Label-set sizes for the known benchmark datasets; load_dataset falls back
# to inferring the class count from the labels for names not listed here.
NUM_CLASSES = {
'trec': 6,
'imdb': 2,
'agnews': 4
}
class DataToDataset(Dataset):
    """Wrap a two-column (label, text) table as a torch ``Dataset``.

    ``data`` must expose a ``.values`` array whose first column holds the
    labels and whose second column holds the raw texts (e.g. a pandas
    DataFrame read from a header-less CSV).
    """

    def __init__(self, data):
        values = data.values
        # Column 0 carries the label, column 1 the text.
        self.labels = values[:, 0]
        self.texts = values[:, 1]

    def __len__(self):
        # One example per stored label.
        return len(self.labels)

    def __getitem__(self, index):
        # Note: yields (text, label) — reversed relative to column order.
        return self.texts[index], self.labels[index]
def load_dataset(data_path, dataset_name):
    """Load a header-less (label, text) CSV and wrap it as a Dataset.

    Args:
        data_path: Path to a ``.csv`` file whose first column is the
            integer label and whose second column is the text.
        dataset_name: Key into ``NUM_CLASSES``; for unknown names the class
            count is inferred as ``max(label) + 1`` (labels are presumably
            0-based integers — verify against the data files).

    Returns:
        Tuple of ``(DataToDataset, num_classes)``.

    Raises:
        ValueError: If ``data_path`` does not end in ``.csv``.
    """
    # Raise instead of assert: asserts are stripped under ``python -O``,
    # which would let a non-CSV path fall through to read_csv.
    if data_path.split(".")[-1] != 'csv':
        raise ValueError(f"expected a .csv file, got: {data_path}")
    data = pd.read_csv(data_path, header=None)
    if dataset_name in NUM_CLASSES:
        num_classes = NUM_CLASSES[dataset_name]
    else:
        # int() guards against a numpy scalar leaking into the %d format.
        num_classes = int(max(data.values[:, 0])) + 1
    logger.info('num_classes is %d', num_classes)
    return DataToDataset(data), num_classes
class SelfMixDataset(Dataset):
    """Dataset view used by the SelfMix training phases.

    Depending on ``mode`` it exposes the whole corpus ("all"), the subset
    flagged clean by ``pred`` ("labeled"), or its complement ("unlabeled").

    Args:
        data_args: Config object; only ``max_sentence_len`` is read here.
        dataset: Object with ``labels`` (fancy-indexable, e.g. numpy array)
            and ``texts`` attributes.
        tokenizer: HuggingFace-style callable returning a mapping with
            'input_ids' and 'attention_mask' tensors.
        mode: One of "all", "labeled", "unlabeled".
        pred: 0/1 array marking clean samples; required for the
            "labeled"/"unlabeled" modes (must support ``.nonzero()``).
        probability: Per-sample clean probability; required for "labeled".
    """

    # NOTE: defaults changed from mutable ``[]`` to ``None`` — the old
    # defaults were never usable anyway (lists lack ``.nonzero()``).
    def __init__(self, data_args, dataset, tokenizer, mode, pred=None, probability=None):
        self.data_args = data_args
        self.labels = dataset.labels
        self.inputs = dataset.texts
        self.mode = mode
        self.tokenizer = tokenizer
        if self.mode == "labeled":
            # Keep only samples flagged clean (pred == 1).
            pred_idx = pred.nonzero()[0]
            self.inputs = [self.inputs[idx] for idx in pred_idx]
            self.labels = self.labels[pred_idx]
            self.prob = [probability[idx] for idx in pred_idx]
            self.pred_idx = pred_idx
        elif self.mode == "unlabeled":
            # Complement of the labeled subset (pred == 0).
            pred_idx = (1 - pred).nonzero()[0]
            self.inputs = [self.inputs[idx] for idx in pred_idx]
            self.labels = self.labels[pred_idx]
            self.pred_idx = pred_idx

    def __len__(self):
        return len(self.inputs)

    def get_tokenized(self, text):
        """Tokenize one text; return (input_ids, attention_mask) with the
        batch dimension of 1 (from ``return_tensors='pt'``) squeezed away."""
        tokens = self.tokenizer(text, padding='max_length', truncation=True,
                                max_length=self.data_args.max_sentence_len, return_tensors='pt')
        # The original squeezed every tensor in place and then squeezed the
        # two returned tensors again; one squeeze per returned tensor suffices.
        return tokens['input_ids'].squeeze(), tokens['attention_mask'].squeeze()

    def __getitem__(self, index):
        text = self.inputs[index]
        input_id, att_mask = self.get_tokenized(text)
        if self.mode == 'labeled':
            return input_id, att_mask, self.labels[index], self.prob[index], self.pred_idx[index]
        elif self.mode == 'unlabeled':
            return input_id, att_mask, self.pred_idx[index]
        elif self.mode == 'all':
            return input_id, att_mask, self.labels[index], index
        # Fail loudly instead of silently returning None for unknown modes.
        raise ValueError(f"unknown mode: {self.mode!r}")
class SelfMixData:
    """Factory building the DataLoaders used by the SelfMix training loop."""

    def __init__(self, data_args, datasets, tokenizer):
        self.data_args = data_args
        self.datasets = datasets
        self.tokenizer = tokenizer

    # NOTE: defaults changed from mutable ``[]`` to ``None``; the lists were
    # never usable (the "train" branch requires a real pred/prob array).
    def run(self, mode, pred=None, prob=None):
        """Build the loader(s) for one training phase.

        Args:
            mode: "all" → one shuffled loader over the full corpus;
                "train" → a ``(labeled, unlabeled)`` loader pair split by
                ``pred``.
            pred: 0/1 clean-sample mask; required when ``mode == "train"``.
            prob: Per-sample clean probabilities; required when
                ``mode == "train"``.

        Returns:
            A ``DataLoader`` for "all"; a ``(labeled_loader,
            unlabeled_loader)`` tuple for "train".

        Raises:
            ValueError: For any other mode (previously returned ``None``).
        """
        if mode == "all":
            all_dataset = SelfMixDataset(
                data_args=self.data_args,
                dataset=self.datasets,
                tokenizer=self.tokenizer,
                mode="all")
            return DataLoader(
                dataset=all_dataset,
                batch_size=self.data_args.batch_size,
                shuffle=True,
                num_workers=2)
        if mode == "train":
            labeled_dataset = SelfMixDataset(
                data_args=self.data_args,
                dataset=self.datasets,
                tokenizer=self.tokenizer,
                mode="labeled",
                pred=pred, probability=prob)
            labeled_trainloader = DataLoader(
                dataset=labeled_dataset,
                batch_size=self.data_args.batch_size_mix,
                shuffle=True,
                num_workers=2)
            unlabeled_dataset = SelfMixDataset(
                data_args=self.data_args,
                dataset=self.datasets,
                tokenizer=self.tokenizer,
                mode="unlabeled",
                pred=pred)
            unlabeled_trainloader = DataLoader(
                dataset=unlabeled_dataset,
                batch_size=self.data_args.batch_size_mix,
                shuffle=True,
                num_workers=2)
            return labeled_trainloader, unlabeled_trainloader
        raise ValueError(f"unknown mode: {mode!r}")