import torch
from reformer_pytorch import ReformerEncDec

# source (DE) and target (EN) sequence lengths
DE_SEQ_LEN = 256
EN_SEQ_LEN = 256

# encoder-decoder Reformer: a 6-layer encoder over the source sequence
# and a 6-layer decoder that attends to the encoder's output
enc_dec = ReformerEncDec(
    dim = 128,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_max_seq_len = DE_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_max_seq_len = EN_SEQ_LEN
)

# random token sequences standing in for a real parallel corpus
train_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long()
train_seq_out = torch.randint(0, 20000, (1, EN_SEQ_LEN)).long()
input_mask = torch.ones(1, DE_SEQ_LEN).bool()  # attend to every source position

loss = enc_dec(train_seq_in, train_seq_out, return_loss = True, enc_input_mask = input_mask)
loss.backward()
# learn
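
# a minimal sketch of the parameter update implied by "# learn" above;
# the choice of Adam and the learning rate are assumptions, not from the source
optim = torch.optim.Adam(enc_dec.parameters(), lr = 1e-4)
optim.step()       # apply the gradients computed by loss.backward()
optim.zero_grad()  # clear gradients before the next step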

# evaluate with the following
eval_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long()
eval_seq_out_start = torch.tensor([[0]]).long()  # assume 0 is id of start token

print(eval_seq_in)
print(eval_seq_out_start)
# autoregressively decode up to EN_SEQ_LEN tokens, stopping early at the eos token
samples = enc_dec.generate(eval_seq_in, eval_seq_out_start, seq_len = EN_SEQ_LEN, eos_token = 1)  # assume 1 is id of stop token
print(samples)
print(samples.shape)  # (1, <= 256) -- decode the tokens with your tokenizer
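
# a hypothetical post-processing step, not part of the source script:
# trim the generated sample at the first eos token (id 1, as assumed above)
# before handing the ids to a detokenizer
eos_id = 1
hits = (samples[0] == eos_id).nonzero()
if hits.numel() > 0:
    samples = samples[:, :hits[0, 0].item()]
print(samples.shape)  # eos token and everything after it removed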