-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathValidation_model.py
57 lines (49 loc) · 2.48 KB
/
Validation_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
## Author : Sandeep Ramachandra, sandeep.ramachandra@student.uni-siegen.de
## Description : Python file containing a PyTorch Lightning-compatible class for training the validation classifiers used
## in this work (both LSTM- and transformer-based networks).
import pytorch_lightning as pl
# from pytorch_lightning.metrics.functional import f1
from pytorch_lightning.metrics import F1
import torch
import torch.nn as nn
class Net(pl.LightningModule):
    """LightningModule that trains a wrapped classifier with cross-entropy and F1 tracking.

    The wrapped ``model`` performs the forward computation. This class adds the
    loss, per-epoch loss/F1 logging for training and validation, the optimizer /
    LR-scheduler configuration, and CLI argument registration.

    Args:
        model: module applied to ``batch["data"]``; its output is fed to
            ``nn.CrossEntropyLoss`` and F1, so presumably class logits of shape
            ``(batch, num_classes)`` -- TODO confirm against the wrapped models.
        num_classes: number of target classes, passed to both F1 metrics.
        classes_weight: optional per-class weight passed to ``nn.CrossEntropyLoss``.
        lr: learning rate for the Adamax optimizer.
        monitor: name of the logged metric that drives ``ReduceLROnPlateau``
            (e.g. ``"val_f1_score"`` or ``"val_loss"``).
    """

    def __init__(self, model, num_classes, classes_weight=None, lr=0.0001, monitor="val_f1_score"):
        super().__init__()
        self.model = model
        self.lr = lr
        self.monitor = monitor
        self.num_classes = num_classes
        self.criterion = nn.CrossEntropyLoss(weight=classes_weight)
        # Separate metric objects so train and val statistics accumulate independently.
        self.train_f1 = F1(num_classes=num_classes)
        self.val_f1 = F1(num_classes=num_classes)
        # NOTE(review): this also records `model` (an nn.Module) as a hyperparameter;
        # consider save_hyperparameters(ignore=["model"]) on Lightning versions that support it.
        self.save_hyperparameters()

    def forward(self, input_seq):
        """Run the wrapped model on ``input_seq`` and return its output unchanged."""
        return self.model(input_seq)

    def configure_optimizers(self):
        """Return an Adamax optimizer plus a ReduceLROnPlateau scheduler keyed on ``self.monitor``."""
        optimizer = torch.optim.Adamax(self.parameters(), lr=self.lr)
        # Loss-like metrics should decrease; scores such as F1 should increase.
        mode = "min" if "loss" in self.monitor else "max"
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=mode, patience=3, verbose=True)
        # Lightning only picks up the scheduler from the "lr_scheduler" key; a
        # top-level "scheduler" key is ignored and the LR would never be reduced.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": self.monitor,
            },
        }

    def _shared_step(self, batch, stage, f1_metric):
        """Compute the loss on ``batch``, update ``f1_metric``, and log both as ``{stage}_*`` epoch metrics."""
        labels = batch["label"]
        y_pred = self(batch["data"])
        loss = self.criterion(y_pred, labels)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        f1_metric(y_pred, labels)
        # Logging the metric object lets Lightning compute and reset it at epoch end.
        self.log(f"{stage}_f1_score", f1_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log ``train_loss`` / ``train_f1_score``."""
        return self._shared_step(batch, "train", self.train_f1)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log ``val_loss`` / ``val_f1_score``."""
        return self._shared_step(batch, "val", self.val_f1)

    @staticmethod
    def add_model_args(parent_parser):
        """Return a child parser of ``parent_parser`` that adds the ``--lr`` and ``--monitor`` options."""
        # Imported here because the module's import block does not provide ArgumentParser.
        from argparse import ArgumentParser

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=0.0001)
        parser.add_argument('--monitor', type=str, default="val_f1_score")
        return parser