import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, Dropout, NLLLoss
from transformers import BertModel


class ActionBertModel(torch.nn.Module):
    """BERT encoder with a linear action classifier over the pooled output."""

    def __init__(self, model_name_or_path, dropout, num_action_labels):
        super(ActionBertModel, self).__init__()
        self.bert_model = BertModel.from_pretrained(model_name_or_path)
        self.dropout = Dropout(dropout)
        self.num_action_labels = num_action_labels
        self.action_classifier = nn.Linear(self.bert_model.config.hidden_size,
                                           num_action_labels)

    def forward(self, input_ids, attention_mask, token_type_ids, action_label=None):
        # Classify from the pooled [CLS] representation.
        pooled_output = self.bert_model(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids,
                                        return_dict=False)[1]
        action_logits = self.action_classifier(self.dropout(pooled_output))
        # Compute the loss only if labels are provided.
        if action_label is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(action_logits.view(-1, self.num_action_labels),
                            action_label.long())
        else:
            loss = torch.tensor(0.0)
        return action_logits, loss
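
# Illustrative usage sketch (not from the original repo; the checkpoint name,
# label count, and example batch below are assumptions):
#
#     from transformers import BertTokenizer
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = ActionBertModel("bert-base-uncased", dropout=0.1, num_action_labels=12)
#     batch = tokenizer(["book a table for two"], return_tensors="pt")
#     logits, loss = model(batch["input_ids"], batch["attention_mask"],
#                          batch["token_type_ids"], action_label=torch.tensor([3]))
#     # logits: (1, 12) action scores; loss: scalar cross-entropy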


class SchemaActionBertModel(torch.nn.Module):
    """Schema attention model: attends from dialogue tokens to schema tokens
    and pools the attention mass into a distribution over actions."""

    def __init__(self, model_name_or_path, dropout, num_action_labels):
        super(SchemaActionBertModel, self).__init__()
        self.bert_model = BertModel.from_pretrained(model_name_or_path)
        self.dropout = Dropout(dropout)
        self.num_action_labels = num_action_labels
        # Scalar "schema applies" score computed from the pooled output.
        self.p_schema = nn.Linear(self.bert_model.config.hidden_size, 1)

    def forward(self,
                input_ids,
                attention_mask,
                token_type_ids,
                tasks,
                action_label,
                sc_input_ids,
                sc_attention_mask,
                sc_token_type_ids,
                sc_tasks,
                sc_action_label):
        # Encode the dialogue history and the schema turns with the same encoder.
        all_output, pooled_output = self.bert_model(input_ids=input_ids,
                                                    attention_mask=attention_mask,
                                                    token_type_ids=token_type_ids,
                                                    return_dict=False)
        sc_all_output, sc_pooled_output = self.bert_model(input_ids=sc_input_ids,
                                                          attention_mask=sc_attention_mask,
                                                          token_type_ids=sc_token_type_ids,
                                                          return_dict=False)
        # Attention from each dialogue token to every schema token, summed over
        # the tokens within each schema turn.
        all_output_flat = all_output.view(-1, all_output.size(-1))
        i_probs = F.softmax(all_output_flat.mm(
            sc_all_output.view(-1, sc_all_output.size(-1)).t()), dim=-1).view(
            all_output_flat.size(0), -1, sc_input_ids.size(-1)).sum(dim=-1)
        # Average the per-token distributions over each dialogue sequence.
        probs = i_probs.view(input_ids.size(0), -1, i_probs.size(-1)).mean(dim=1)
        # Pool the attention over schema turns into per-action probabilities.
        action_probs = torch.zeros(probs.size(0), self.num_action_labels,
                                   device=probs.device).scatter_add(
            -1, sc_action_label.unsqueeze(0).repeat(probs.size(0), 1), probs)
        # Schema-confidence gate (computed here but only applied in predict()).
        sc_prob = torch.sigmoid(self.p_schema(pooled_output))
        action_lps = torch.log(action_probs + 1e-10)
        # Compute the loss only if labels are provided.
        if action_label is not None:
            loss_fct = NLLLoss()
            loss = loss_fct(action_lps.view(-1, self.num_action_labels),
                            action_label.long())
        else:
            loss = torch.tensor(0.0)
        return action_lps, loss
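
    # A toy walkthrough of the scatter_add pooling above (the numbers are
    # illustrative assumptions, not values from the original code):
    #
    #     probs = torch.tensor([[0.2, 0.5, 0.3]])      # attention over 3 schema turns
    #     sc_action_label = torch.tensor([[1, 0, 1]])  # action id of each schema turn
    #     torch.zeros(1, 2).scatter_add(-1, sc_action_label, probs)
    #     # -> tensor([[0.5000, 0.5000]]): turns sharing an action pool their mass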

    def predict(self,
                input_ids,
                attention_mask,
                token_type_ids,
                tasks,
                sc_all_output,
                sc_pooled_output,
                sc_tasks,
                sc_action_label):
        # Schema encodings are precomputed and passed in at prediction time.
        all_output, pooled_output = self.bert_model(input_ids=input_ids,
                                                    attention_mask=attention_mask,
                                                    token_type_ids=token_type_ids,
                                                    return_dict=False)
        all_output_flat = all_output.view(-1, all_output.size(-1))
        i_probs = F.softmax(all_output_flat.mm(
            sc_all_output.view(-1, sc_all_output.size(-1)).t()), dim=-1).view(
            all_output_flat.size(0), -1, sc_all_output.size(-2)).sum(dim=-1)
        probs = i_probs.view(input_ids.size(0), -1, i_probs.size(-1)).mean(dim=1)
        # Zero out any attention across different tasks.
        for i in range(probs.size(0)):
            for j in range(probs.size(1)):
                if tasks[i] != sc_tasks[j]:
                    probs[i, j] = 0
        action_probs = torch.zeros(probs.size(0), self.num_action_labels,
                                   device=probs.device).scatter_add(
            -1, sc_action_label.unsqueeze(0).repeat(probs.size(0), 1), probs)
        # Gate the action distribution by the schema-confidence score; the small
        # epsilon (as in forward) keeps log() finite for task-masked entries.
        sc_prob = torch.sigmoid(self.p_schema(pooled_output))
        action_lps = torch.log(action_probs * sc_prob + 1e-10)
        return action_lps, 0
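

# A minimal end-to-end sketch (assumptions: "bert-base-uncased" weights, a toy
# two-turn schema, and 4 action labels; runs on CPU because the classes above
# allocate on the inputs' device rather than calling .cuda()):
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = SchemaActionBertModel("bert-base-uncased", dropout=0.1,
                                  num_action_labels=4)

    # One dialogue turn plus a two-turn schema, padded to a common length.
    turn = tokenizer(["i want to book a table"], return_tensors="pt")
    schema = tokenizer(["ask for the party size", "confirm the booking"],
                       return_tensors="pt", padding=True)

    action_lps, loss = model(turn["input_ids"],
                             turn["attention_mask"],
                             turn["token_type_ids"],
                             tasks=["restaurant"],
                             action_label=torch.tensor([1]),
                             sc_input_ids=schema["input_ids"],
                             sc_attention_mask=schema["attention_mask"],
                             sc_token_type_ids=schema["token_type_ids"],
                             sc_tasks=["restaurant", "restaurant"],
                             sc_action_label=torch.tensor([1, 2]))
    print(action_lps.shape, loss.item())  # torch.Size([1, 4]) and a scalar loss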