-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel_ori.py
439 lines (383 loc) · 15.4 KB
/
model_ori.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules import (
TransformerLayer,
AxialTransformerLayer,
LearnedPositionalEmbedding,
SinusoidalPositionalEmbedding,
RobertaLMHead,
ESM1bLayerNorm,
ContactPredictionHead,
)
from .axial_attention import RowSelfAttention, ColumnSelfAttention
class ProteinBertModel(nn.Module):
@classmethod
def add_args(cls, parser):
parser.add_argument(
"--num_layers", default=36, type=int, metavar="N", help="number of layers"
)
parser.add_argument(
"--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
)
parser.add_argument(
"--logit_bias", action="store_true", help="whether to apply bias to logits"
)
parser.add_argument(
"--ffn_embed_dim",
default=5120,
type=int,
metavar="N",
help="embedding dimension for FFN",
)
parser.add_argument(
"--attention_heads",
default=20,
type=int,
metavar="N",
help="number of attention heads",
)
def __init__(self, args, alphabet):
super().__init__()
self.args = args
self.alphabet_size = len(alphabet)
self.padding_idx = alphabet.padding_idx
self.mask_idx = alphabet.mask_idx
self.cls_idx = alphabet.cls_idx
self.eos_idx = alphabet.eos_idx
self.prepend_bos = alphabet.prepend_bos
self.append_eos = alphabet.append_eos
self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
if self.args.arch == "roberta_large":
self.model_version = "ESM-1b"
self._init_submodules_esm1b()
else:
self.model_version = "ESM-1"
self._init_submodules_esm1()
def _init_submodules_common(self):
self.embed_tokens = nn.Embedding(
self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
)
self.layers = nn.ModuleList(
[
TransformerLayer(
self.args.embed_dim,
self.args.ffn_embed_dim,
self.args.attention_heads,
add_bias_kv=(self.model_version != "ESM-1b"),
use_esm1b_layer_norm=(self.model_version == "ESM-1b"),
)
for _ in range(self.args.layers)
]
)
self.contact_head = ContactPredictionHead(
self.args.layers * self.args.attention_heads,
self.prepend_bos,
self.append_eos,
eos_idx=self.eos_idx,
)
def _init_submodules_esm1b(self):
self._init_submodules_common()
self.embed_scale = 1
self.embed_positions = LearnedPositionalEmbedding(
self.args.max_positions, self.args.embed_dim, self.padding_idx
)
self.emb_layer_norm_before = (
ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
)
self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
self.lm_head = RobertaLMHead(
embed_dim=self.args.embed_dim,
output_dim=self.alphabet_size,
weight=self.embed_tokens.weight,
)
def _init_submodules_esm1(self):
self._init_submodules_common()
self.embed_scale = math.sqrt(self.args.embed_dim)
self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
self.embed_out_bias = None
if self.args.final_bias:
self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))
def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
if return_contacts:
need_head_weights = True
assert tokens.ndim == 2
padding_mask = tokens.eq(self.padding_idx) # B, T
x = self.embed_scale * self.embed_tokens(tokens)
if getattr(self.args, "token_dropout", False):
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
# x: B x T x C
mask_ratio_train = 0.15 * 0.8
src_lengths = (~padding_mask).sum(-1)
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
x = x + self.embed_positions(tokens)
if self.model_version == "ESM-1b":
if self.emb_layer_norm_before:
x = self.emb_layer_norm_before(x)
if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
repr_layers = set(repr_layers)
hidden_representations = {}
if 0 in repr_layers:
hidden_representations[0] = x
if need_head_weights:
attn_weights = []
# (B, T, E) => (T, B, E)
x = x.transpose(0, 1)
if not padding_mask.any():
padding_mask = None
for layer_idx, layer in enumerate(self.layers):
x, attn = layer(
x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
)
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
if need_head_weights:
# (H, B, T, T) => (B, H, T, T)
attn_weights.append(attn.transpose(1, 0))
if self.model_version == "ESM-1b":
x = self.emb_layer_norm_after(x)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
# last hidden representation should have layer norm applied
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x
x = self.lm_head(x)
else:
x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
result = {"logits": x, "representations": hidden_representations}
if need_head_weights:
# attentions: B x L x H x T x T
attentions = torch.stack(attn_weights, 1)
if self.model_version == "ESM-1":
# ESM-1 models have an additional null-token for attention, which we remove
attentions = attentions[..., :-1]
if padding_mask is not None:
attention_mask = 1 - padding_mask.type_as(attentions)
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts
return result
def predict_contacts(self, tokens):
return self(tokens, return_contacts=True)["contacts"]
@property
def num_layers(self):
return self.args.layers
class MSATransformer(nn.Module):
@classmethod
def add_args(cls, parser):
# fmt: off
parser.add_argument(
"--num_layers",
default=12,
type=int,
metavar="N",
help="number of layers"
)
parser.add_argument(
"--embed_dim",
default=768,
type=int,
metavar="N",
help="embedding dimension"
)
parser.add_argument(
"--logit_bias",
action="store_true",
help="whether to apply bias to logits"
)
parser.add_argument(
"--ffn_embed_dim",
default=3072,
type=int,
metavar="N",
help="embedding dimension for FFN",
)
parser.add_argument(
"--attention_heads",
default=12,
type=int,
metavar="N",
help="number of attention heads",
)
parser.add_argument(
"--dropout",
default=0.1,
type=float,
help="Dropout to apply."
)
parser.add_argument(
"--attention_dropout",
default=0.1,
type=float,
help="Dropout to apply."
)
parser.add_argument(
"--activation_dropout",
default=0.1,
type=float,
help="Dropout to apply."
)
parser.add_argument(
"--max_tokens_per_msa",
default=2 ** 14,
type=int,
help=(
"Used during inference to batch attention computations in a single "
"forward pass. This allows increased input sizes with less memory."
),
)
# fmt: on
def __init__(self, args, alphabet):
super().__init__()
self.args = args
self.alphabet_size = len(alphabet)
self.padding_idx = alphabet.padding_idx
self.mask_idx = alphabet.mask_idx
self.cls_idx = alphabet.cls_idx
self.eos_idx = alphabet.eos_idx
self.prepend_bos = alphabet.prepend_bos
self.append_eos = alphabet.append_eos
self.embed_tokens = nn.Embedding(
self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
)
if getattr(self.args, "embed_positions_msa", False):
emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim)
self.msa_position_embedding = nn.Parameter(
0.01 * torch.randn(1, 1024, 1, emb_dim),
requires_grad=True,
)
else:
self.register_parameter("msa_position_embedding", None)
self.dropout_module = nn.Dropout(self.args.dropout)
self.layers = nn.ModuleList(
[
AxialTransformerLayer(
self.args.embed_dim,
self.args.ffn_embed_dim,
self.args.attention_heads,
self.args.dropout,
self.args.attention_dropout,
self.args.activation_dropout,
getattr(self.args, "max_tokens_per_msa", self.args.max_tokens),
)
for _ in range(self.args.layers)
]
)
self.contact_head = ContactPredictionHead(
self.args.layers * self.args.attention_heads,
self.prepend_bos,
self.append_eos,
eos_idx=self.eos_idx,
)
self.embed_positions = LearnedPositionalEmbedding(
self.args.max_positions,
self.args.embed_dim,
self.padding_idx,
)
self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim)
self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
self.lm_head = RobertaLMHead(
embed_dim=self.args.embed_dim,
output_dim=self.alphabet_size,
weight=self.embed_tokens.weight,
)
def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
if return_contacts:
need_head_weights = True
assert tokens.ndim == 3
batch_size, num_alignments, seqlen = tokens.size()
padding_mask = tokens.eq(self.padding_idx) # B, R, C
if not padding_mask.any():
padding_mask = None
x = self.embed_tokens(tokens)
x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
if self.msa_position_embedding is not None:
if x.size(1) > 1024:
raise RuntimeError(
"Using model with MSA position embedding trained on maximum MSA "
f"depth of 1024, but received {x.size(1)} alignments."
)
x += self.msa_position_embedding[:, :num_alignments]
x = self.emb_layer_norm_before(x)
x = self.dropout_module(x)
if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
repr_layers = set(repr_layers)
hidden_representations = {}
if 0 in repr_layers:
hidden_representations[0] = x
if need_head_weights:
row_attn_weights = []
col_attn_weights = []
# B x R x C x D -> R x C x B x D
x = x.permute(1, 2, 0, 3)
for layer_idx, layer in enumerate(self.layers):
x = layer(
x,
self_attn_padding_mask=padding_mask,
need_head_weights=need_head_weights,
)
if need_head_weights:
x, col_attn, _ = x
if layer_idx == 0:
col_attn_weights = symmetrize(col_attn.permute(2, 0, 1, 3, 4))
else:
col_attn_weights += symmetrize(col_attn.permute(2, 0, 1, 3, 4))
# if need_head_weights:
# x, col_attn, row_attn = x
# # H x C x B x R x R -> B x H x C x R x R
# col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
# # H x B x C x C -> B x H x C x C
# row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3)
x = self.emb_layer_norm_after(x)
x = x.permute(2, 0, 1, 3) # R x C x B x D -> B x R x C x D
# last hidden representation should have layer norm applied
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x
x = self.lm_head(x)
result = {"logits": x, "representations": hidden_representations}
if need_head_weights:
# # col_attentions: B x L x H x C x R x R
# col_attentions = torch.stack(col_attn_weights, 1)
# # row_attentions: B x L x H x C x C
# row_attentions = torch.stack(row_attn_weights, 1)
# result["col_attentions"] = col_attentions
# result["row_attentions"] = row_attentions
result["col_attentions"] = col_attn_weights.sum(1)
# if return_contacts:
# contacts = self.contact_head(tokens, row_attentions)
# result["contacts"] = contacts
return result
def predict_contacts(self, tokens):
return self(tokens, return_contacts=True)["contacts"]
@property
def num_layers(self):
return self.args.layers
def max_tokens_per_msa_(self, value: int) -> None:
"""The MSA Transformer automatically batches attention computations when
gradients are disabled to allow you to pass in larger MSAs at test time than
you can fit in GPU memory. By default this occurs when more than 2^14 tokens
are passed in the input MSA. You can set this value to infinity to disable
this behavior.
"""
for module in self.modules():
if isinstance(module, (RowSelfAttention, ColumnSelfAttention)):
module.max_tokens_per_msa = value
def symmetrize(x):
"Make layer symmetric in final two dimensions, used for contact prediction."
# print(x.size())
# exit()
return x + x.transpose(-1, -2)