main.py
import argparse
import sys

import nltk
import numpy as np
import torch
import youtokentome

from dataloader import Dataloader
from models import SentenceTaggerRNN
from train_model import *
from utils import *

sys.setrecursionlimit(1500)  # long documents can trigger deep recursion during preprocessing
nltk.download('punkt')  # sentence tokenizer models required for sentence splitting
def main(num_samples, batch_size, num_epochs, clip, seed):
    seed_everything(seed)

    # Load and clean the CNN/DailyMail splits.
    train = clean_dataset('dataset/cnn_dailymail/train.csv')
    val = clean_dataset('dataset/cnn_dailymail/validation.csv')
    test = clean_dataset('dataset/cnn_dailymail/test.csv')
    train_records = read_records(train, shuffle=True)
    val_records = read_records(val, shuffle=False)
    test_records = read_records(test, shuffle=False)

    log('Training the BPE tokenizer...')
    train_bpe(train_records, "BPE_model.bin")
    log('Done')
    bpe_tokenizer = youtokentome.BPE('BPE_model.bin')
    vocabulary = bpe_tokenizer.vocab()

    # Cache oracle summaries in RAM.
    ext_train_records = add_oracle_summary_to_records(train_records, nrows=num_samples)
    ext_val_records = add_oracle_summary_to_records(val_records, nrows=num_samples)
    ext_test_records = add_oracle_summary_to_records(test_records, nrows=num_samples)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_iterator = Dataloader(ext_train_records, vocabulary, batch_size, bpe_tokenizer, device=device)
    val_iterator = Dataloader(ext_val_records, vocabulary, batch_size, bpe_tokenizer, device=device)
    test_iterator = Dataloader(ext_test_records, vocabulary, batch_size, bpe_tokenizer, device=device)

    vocab_size = len(vocabulary)
    model = SentenceTaggerRNN(vocab_size).to(device)
    params_count = np.sum([p.numel() for p in model.parameters() if p.requires_grad])
    log("Trainable params: {}".format(params_count))

    log('Training the model...')
    training(model, ext_train_records, train_iterator, val_iterator, device,
             num_epochs, clip, use_class_weights=True)

    # Sanity check: the cached training set must cover at least as many
    # examples as the test iterator will consume.
    assert len(ext_train_records) >= len(test_iterator) * batch_size, "Not enough examples"
    log('Evaluating the model quality on the test data...')
    inference_summarunner_without_novelty_with_class_weights = inference_summarunner(
        model, test_iterator, top_k=3, threshold=None)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_samples', type=int, default=1000, help='Number of examples to use')
    parser.add_argument('--batch_size', type=int, default=50, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=10, help='Number of epochs')
    parser.add_argument('--clip', type=float, default=1.0, help='Maximum gradient norm for clipping')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()
    main(**vars(args))
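
# Example invocation, assuming the CNN/DailyMail CSVs are already present
# under dataset/cnn_dailymail/ and the helper modules (dataloader, models,
# train_model, utils) are importable from the working directory; all flags
# below mirror the argparse defaults and can be omitted:
#
#   python main.py --num_samples 1000 --batch_size 50 --num_epochs 10 --clip 1.0 --seed 42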