train.py

import wandb
from datasets import load_dataset
from transformers import (
    DataCollatorWithPadding,
    PerceiverForSequenceClassification,
    PerceiverTokenizer,
    Trainer,
    TrainingArguments,
)

# Local helper modules from this repo (metrics.py and dataset.py).
from metrics import compute_metrics
from dataset import k_fold_split


def train_folds(dataset, n_folds=10):
    tokenizer = PerceiverTokenizer.from_pretrained('deepmind/language-perceiver')
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding='max_length')

    # Build the label <-> id mappings from the dataset's label names.
    labels = dataset.features['label'].names
    id2label = {i: label for i, label in enumerate(labels)}
    label2id = {label: i for i, label in enumerate(labels)}

    # Tokenize once up front; padding is handled later by the data collator.
    tokenized_dataset = dataset.map(
        lambda examples: tokenizer(
            examples['sentence'],
            truncation=True
        ),
        batched=True
    )

    # Use the n_folds argument instead of a hard-coded split count.
    dataset_splits = k_fold_split(
        dataset=tokenized_dataset,
        n_splits=n_folds,
        shuffle=True
    )

    # Hyperparameters shared by every fold; output_dir is set per fold.
    default_training_args = {
        'per_device_train_batch_size': 16,
        'per_device_eval_batch_size': 16,
        'num_train_epochs': 4,
        'learning_rate': 2e-5,
        'evaluation_strategy': 'epoch',
        'save_strategy': 'epoch',
        'save_total_limit': 2,
        'logging_strategy': 'steps',
        'logging_first_step': True,
        'logging_steps': 5,
        'report_to': 'wandb'
    }

    # Train a fresh model per fold, each with its own output directory
    # (and hence its own wandb run).
    for current_fold, fold_data in enumerate(dataset_splits):
        print(f'Starting fold {current_fold}')
        train_split, eval_split = fold_data
        model = PerceiverForSequenceClassification.from_pretrained(
            'deepmind/language-perceiver',
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id
        )
        train_model(
            output_dir=f'fold_{current_fold}',
            model=model,
            tokenizer=tokenizer,
            data_collator=data_collator,
            training_args=default_training_args,
            train_split=train_split,
            eval_split=eval_split
        )
        print(f'Finished training fold {current_fold}')


def train_model(model, output_dir, tokenizer, data_collator, training_args,
                train_split, eval_split):
    training_args = TrainingArguments(
        output_dir=output_dir,
        **training_args
    )
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )
    trainer.train()
    # Close the current wandb run so the next fold starts a fresh one.
    wandb.finish()
    return trainer
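

# compute_metrics is imported from the local metrics.py, which is not shown
# on this page. The function below is a minimal sketch of what such a metric
# function typically looks like for Trainer, assuming standard sklearn
# metrics; the name and body are hypothetical, not the repo's actual code.
def _example_compute_metrics(eval_pred):
    """Hypothetical stand-in for metrics.compute_metrics (illustration only)."""
    import numpy as np
    from sklearn.metrics import accuracy_score, f1_score

    # Trainer passes an EvalPrediction: .predictions are logits,
    # .label_ids are the gold labels.
    preds = np.argmax(eval_pred.predictions, axis=-1)
    return {
        'accuracy': accuracy_score(eval_pred.label_ids, preds),
        'f1': f1_score(eval_pred.label_ids, preds, average='macro'),
    }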


if __name__ == '__main__':
    wandb.login()
    financial_phrasebank = load_dataset(
        path='financial_phrasebank',
        name='sentences_50agree',
        split='train'
    )
    train_folds(financial_phrasebank, n_folds=10)
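

# k_fold_split is imported from the local dataset.py, which is not shown on
# this page. Below is a minimal sketch of what it might look like, assuming
# it wraps sklearn's KFold over a Hugging Face Dataset; the name, signature
# details, and random_state are hypothetical.
def _example_k_fold_split(dataset, n_splits=10, shuffle=True):
    """Hypothetical stand-in for dataset.k_fold_split (illustration only)."""
    from sklearn.model_selection import KFold

    kf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=42)
    for train_idx, eval_idx in kf.split(range(len(dataset))):
        # Dataset.select keeps only the rows at the given indices.
        yield dataset.select(train_idx), dataset.select(eval_idx)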