import numpy as np
import torch
import torch.nn as nn
class SoftEmbedding(nn.Module):
    def __init__(self,
                 wte: nn.Embedding,
                 pt_id: int,
                 n_tokens: int,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = False,
                 load_from_path=None):
        """Wraps a word-embedding layer and injects a learned prompt embedding.

        Args:
            wte (nn.Embedding): original transformer word embedding.
            pt_id (int): id of the prompt placeholder token.
            n_tokens (int): number of learned prompt tokens for the task.
            random_range (float, optional): range for uniform initialization
                (when not initializing from the vocabulary). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initialize from randomly
                chosen vocabulary embeddings. Defaults to False.
            load_from_path (str, optional): if given, load a previously saved
                learned embedding instead of initializing a new one.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.pt_id = pt_id
        self.n_tokens = n_tokens
        if load_from_path is not None:
            self.learned_embedding = torch.load(load_from_path)
        else:
            self.learned_embedding = nn.parameter.Parameter(
                self.initialize_embedding(wte,
                                          n_tokens,
                                          random_range,
                                          initialize_from_vocab))
    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = False,
                             ):
        """Initializes the learned prompt embedding.

        Args:
            same as __init__
        Returns:
            torch.FloatTensor: [n_tokens, embed_dim] initial prompt embedding.
        """
        if initialize_from_vocab:
            # np.random.seed(0)
            # Sample n_tokens distinct vocabulary embeddings as the starting point.
            tokens_ids = np.random.choice(np.arange(self.wte.weight.shape[0]), n_tokens, replace=False)
            return self.wte.weight[tokens_ids].clone().detach()  # TODO: change initialization scheme (include specific words?)
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range)
    def forward(self, tokens):
        """Runs the forward pass.

        Args:
            tokens (torch.long): input token ids, laid out as
                [<s>, <pt> * n_tokens, text tokens...].
        Returns:
            torch.float: embeddings with the prompt-token positions replaced
                by the learned task-specific embedding.
        """
        if tokens.shape[1] == 1:
            # Very special-case design (not recommended in general): single-token
            # inputs, e.g. during step-by-step generation, bypass prompt injection.
            return self.wte(tokens)
        # Expected layout: <s> first, then the prompt placeholder tokens.
        assert tokens[0, 0] != self.pt_id
        assert tokens[0, 1] == self.pt_id
        learned_embeddings_repeated = self.learned_embedding.repeat(tokens.size(0), 1, 1)
        assert learned_embeddings_repeated.shape[0] == tokens.shape[0]  # [bs, n_tokens, embed_dim]
        assert learned_embeddings_repeated.shape[1] == self.n_tokens
        assert tokens.shape[1] > self.n_tokens
        input_embedding_left = self.wte(tokens[:, :1])                   # embedding of <s>
        input_embedding_right = self.wte(tokens[:, self.n_tokens + 1:])  # embeddings of the text tokens
        return torch.cat([input_embedding_left, learned_embeddings_repeated, input_embedding_right], 1)
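

# A minimal usage sketch (illustrative, not part of the original module): it
# builds a toy nn.Embedding so the expected input layout and output shape of
# SoftEmbedding can be checked in isolation. The vocabulary size, dimensions,
# and token ids below are arbitrary placeholders; in practice `wte` would be a
# model's word-embedding layer (e.g. obtained via `model.get_input_embeddings()`
# in HuggingFace Transformers) and `pt_id` the id of a reserved prompt token.
def _soft_embedding_usage_sketch():
    vocab_size, embed_dim, n_tokens, pt_id = 100, 16, 4, 2
    wte = nn.Embedding(vocab_size, embed_dim)
    soft_wte = SoftEmbedding(wte, pt_id=pt_id, n_tokens=n_tokens)
    # Expected layout: <s> (any non-prompt id), then n_tokens prompt tokens,
    # then the actual text token ids.
    tokens = torch.tensor([[1] + [pt_id] * n_tokens + [5, 6, 7]])
    embeds = soft_wte(tokens)
    assert embeds.shape == (1, tokens.shape[1], embed_dim)
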
class LEmbedding(nn.Module):
    def __init__(self,
                 wte: nn.Embedding,
                 pt_id: int,
                 n_tokens: int,
                 ):
        """Like SoftEmbedding, but the learned embedding is supplied externally.

        Args:
            wte (nn.Embedding): original transformer word embedding.
            pt_id (int): id of the prompt placeholder token.
            n_tokens (int): number of prompt tokens for the task.
        """
        super(LEmbedding, self).__init__()
        self.wte = wte
        self.pt_id = pt_id
        self.n_tokens = n_tokens
        # Must be set to a [batch, n_tokens, embed_dim] tensor before forward().
        self.learned_embedding = None
    def forward(self, tokens):
        """Runs the forward pass.

        Args:
            tokens (torch.long): input token ids, laid out as
                [<s>, <pt> * n_tokens, text tokens...].
        Returns:
            torch.float: embeddings with the prompt-token positions replaced
                by the externally supplied learned embedding.
        """
        if tokens.shape[1] == 1:
            # Very special-case design (not recommended in general): single-token
            # inputs, e.g. during step-by-step generation, bypass prompt injection.
            return self.wte(tokens)
        # Expected layout: <s> first, then the prompt placeholder tokens.
        assert tokens[0, 0] != self.pt_id
        assert tokens[0, 1] == self.pt_id
        assert self.learned_embedding is not None
        assert self.learned_embedding.shape[0] == tokens.shape[0]  # [bs, n_tokens, embed_dim]
        assert self.learned_embedding.shape[1] == self.n_tokens
        assert tokens.shape[1] > self.n_tokens
        input_embedding_left = self.wte(tokens[:, :1])                   # embedding of <s>
        input_embedding_right = self.wte(tokens[:, self.n_tokens + 1:])  # embeddings of the text tokens
        return torch.cat([input_embedding_left, self.learned_embedding, input_embedding_right], 1)
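

# A minimal usage sketch (illustrative, not part of the original module):
# unlike SoftEmbedding, LEmbedding holds no parameter of its own, so a
# per-batch [batch, n_tokens, embed_dim] tensor must be assigned to
# `learned_embedding` before calling forward(). All sizes and ids below are
# arbitrary placeholders.
def _l_embedding_usage_sketch():
    vocab_size, embed_dim, n_tokens, pt_id = 100, 16, 4, 2
    wte = nn.Embedding(vocab_size, embed_dim)
    l_wte = LEmbedding(wte, pt_id=pt_id, n_tokens=n_tokens)
    tokens = torch.tensor([[1] + [pt_id] * n_tokens + [5, 6, 7]])
    # Prompt embeddings produced elsewhere (e.g. by another module), one per example.
    l_wte.learned_embedding = torch.randn(tokens.shape[0], n_tokens, embed_dim)
    embeds = l_wte(tokens)
    assert embeds.shape == (tokens.shape[0], tokens.shape[1], embed_dim)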