-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathFNDetectionModel.py
53 lines (41 loc) · 1.69 KB
/
FNDetectionModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from kobert import get_tokenizer
from kobert import get_pytorch_kobert_model
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
import pandas as pd
class BERTClassifier(nn.Module):
    """Linear classification head applied to the pooled output of a BERT encoder.

    The wrapped ``bert`` callable is invoked with ``input_ids``,
    ``token_type_ids`` and ``attention_mask`` keyword arguments and must
    return a 2-tuple whose second element is the pooled sentence vector.
    """

    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        """Store the encoder and build the classification head.

        Args:
            bert: Encoder called in ``forward``; see the class docstring for
                the calling convention it must satisfy.
            hidden_size: Input width of the linear classifier; must match the
                last dimension of the encoder's pooled output.
            num_classes: Number of output logits.
            dr_rate: Dropout probability applied to the pooled output before
                the classifier. A falsy value (``None`` or ``0``) disables
                dropout and no dropout layer is created.
            params: Unused by this class; kept so existing callers that pass
                it keep working.
        """
        super().__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build a float mask with ones on the first ``valid_length[i]`` positions of row ``i``.

        Args:
            token_ids: 2-D tensor of token ids; only its second dimension
                and device are used.
            valid_length: One non-negative length per row of ``token_ids``
                (tensor or sequence of ints). Lengths larger than the
                sequence length mark the whole row as valid.

        Returns:
            A ``torch.float32`` tensor on ``token_ids.device`` with one row
            per entry of ``valid_length`` and ``token_ids.size(1)`` columns.
        """
        device = token_ids.device
        # Column vector of lengths so the comparison broadcasts across positions.
        lengths = torch.as_tensor(valid_length, device=device).reshape(-1, 1)
        positions = torch.arange(token_ids.size(1), device=device)
        # Single vectorized comparison instead of a per-row Python loop.
        return (positions < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Run the encoder and classify its pooled output.

        Args:
            token_ids: Token id tensor passed to the encoder as ``input_ids``.
            valid_length: Per-row count of non-padding tokens, used to build
                the attention mask.
            segment_ids: Segment ids; cast to ``long`` and passed as
                ``token_type_ids``.

        Returns:
            A tuple ``(elemwise_out, logits)`` where ``elemwise_out`` is the
            pooled output multiplied element-wise by the (detached)
            classifier weight matrix, and ``logits`` is the classifier output
            for the (optionally dropped-out) pooled vector.
        """
        # The mask is already float32 and on token_ids.device.
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        # NOTE(review): unpacking assumes `bert` returns a tuple; recent
        # `transformers` versions may return a ModelOutput unless
        # return_dict=False is configured -- confirm against the encoder used.
        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask)
        # In BERT the pooler output is presumably the representation of the
        # first ([CLS]) token -- verify against the encoder in use.

        # Detached view of the weights: no gradient flows into the classifier
        # through elemwise_out.
        fc_layer_weights = self.classifier.weight.detach()
        # NOTE(review): the weight is (num_classes, hidden_size); assuming
        # pooler is (batch, hidden_size), this product only broadcasts when
        # batch is 1 or equals num_classes (or num_classes is 1) -- confirm
        # the intended batch size with callers.
        elemwise_out = pooler * fc_layer_weights
        out = self.dropout(pooler) if self.dr_rate else pooler
        return elemwise_out, self.classifier(out)