-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
67 lines (57 loc) · 2.3 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import json
import pickle
import opennre
from opennre import encoder, model, framework
import argparse
import logging
logging.basicConfig(level=logging.INFO)

# The transformers library emits a lot of INFO-level chatter while building
# configs, models and tokenizers. Raise just these loggers to WARNING so only
# problems reach the console; everything else keeps the INFO level above.
_QUIET_LOGGERS = (
    'transformers.configuration_utils',
    'transformers.modeling_utils',
    'transformers.tokenization_utils_base',
)
for _quiet_name in _QUIET_LOGGERS:
    logging.getLogger(_quiet_name).setLevel(logging.WARNING)
# Command-line options. This script only evaluates an already-trained
# checkpoint on a test file; no training happens here.
parser = argparse.ArgumentParser(
    description='Evaluate a trained OpenNRE BERT relation classifier on a test set')
parser.add_argument('--model_path', '-m', type=str, required=True,
                    help='Full path to the trained checkpoint to load and evaluate')
parser.add_argument('--test_path', '-t', type=str, required=True,
                    help='Full path to file containing testing data')
parser.add_argument('--relation_path', '-r', type=str, required=True,
                    help='Full path to json file containing relation to index dict')
parser.add_argument('--max_seq_len', '-l', type=int, default=128,
                    help='Maximum sequence length of bert model')
parser.add_argument('--batch_size', '-b', type=int, default=64,
                    help='Batch size for evaluation')
# NOTE(review): --pretrain_path has no default and is passed straight to
# BERTEncoder(pretrain_path=...). Confirm the encoder accepts None; if it
# does not, this option should be required=True.
parser.add_argument('--pretrain_path', '-p', type=str,
                    help='Path to pretrained bert-base model weights')
args = parser.parse_args()
# Relation name -> class index mapping. Its size sets the classifier's output
# count below, so it should be the same mapping the checkpoint was trained with.
with open(args.relation_path, 'r', encoding='utf-8') as rel_file:
    rel2id = json.load(rel_file)

# Use the first GPU when one exists; otherwise fall back to the CPU instead
# of crashing on a hard-coded 'cuda:0'.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Define the sentence encoder
sentence_encoder = opennre.encoder.BERTEncoder(
    max_length=args.max_seq_len,
    pretrain_path=args.pretrain_path,
    mask_entity=False,
)

# Named re_model / re_framework so they do not shadow the `model` and
# `framework` modules imported from opennre at the top of the file.
re_model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
re_model.to(device)

# SentenceRE is used here only for its test loader and eval_model().
# All three data paths point at the test file.
# NOTE(review): presumably the constructor needs train/val paths to build its
# loaders (and optimizer schedule); confirm before passing None.
re_framework = opennre.framework.SentenceRE(
    train_path=args.test_path,
    val_path=args.test_path,
    test_path=args.test_path,
    model=re_model,
    ckpt=args.model_path,
    batch_size=args.batch_size,
    max_epoch=1,
    lr=2e-5,
    opt='adamw',
)

# torch.load unpickles the file: only point --model_path at trusted
# checkpoints. Mapping to the CPU lets a GPU-saved checkpoint load on a
# CPU-only host (assumes the framework's load_state_dict copies into the
# existing parameters, as nn.Module.load_state_dict does).
checkpoint = torch.load(args.model_path, map_location='cpu')
re_framework.load_state_dict(checkpoint['state_dict'])
result = re_framework.eval_model(re_framework.test_loader)

# Print the result
print(f"Accuracy on test set: {result['acc']}")
print(f"Micro Precision: {result['micro_p']}")
print(f"Micro Recall: {result['micro_r']}")
print(f"Micro F1: {result['micro_f1']}")