"""
Open-Domain Question Answering 을 수행하는 inference 코드 입니다.
대부분의 로직은 train.py 와 비슷하나 retrieval, predict 부분이 추가되어 있습니다.
"""
import logging
import sys

import torch
from arguments import cfg, inference_args, model_args
from datasets import load_from_disk
from trainer.trainer import QuestionAnsweringTrainer
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    DataCollatorWithPadding,
)

from model.reader import MRCModel
from utils.load_data import MRC_Dataset
from utils.util import compute_metrics, run_sparse_retrieval, set_seed

logger = logging.getLogger(__name__)


def test():
    # The available arguments are listed in ./arguments.py and in
    # src/transformers/training_args.py inside the transformers package.
    # They can also be inspected by running with the --help flag.

    # Configure logging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # Log the run parameters (informational, on the main process only).
    logger.info("Training/evaluation parameters %s", cfg)

    # Fix the random seed before initializing the model.
    set_seed(cfg.train.seed)

    datasets = load_from_disk(cfg.data.dataset_name)
    print(datasets)
    # Load the pretrained model and tokenizer via the Auto classes.
    # Setting the desired model name as an argument switches these options.
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model.tokenizer_name
        if cfg.model.tokenizer_name is not None
        else cfg.model.model_name_or_path,
        use_fast=True,
    )
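    # use_fast=True loads a fast (Rust-backed) tokenizer; QA post-processing
    # typically relies on its offset mappings to align token-level
    # predictions with the original passage text.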
    if cfg.load_last_model:
        # model = MRCModel(cfg.model_name_or_path)
        # model.load_state_dict(torch.load(cfg.model.load_last_model))
        # print(f"model is from {cfg.model.load_last_model}")
        model = AutoModelForQuestionAnswering.from_pretrained(
            model_args.trained_model_name,
            from_tf=bool(".ckpt" in model_args.trained_model_name),
        )
    else:
        # Fall back to the base checkpoint so `model` is always defined.
        model = AutoModelForQuestionAnswering.from_pretrained(
            cfg.model.model_name_or_path
        )
    # If True: run passage retrieval.
    if cfg.test.test_mode:
        datasets = run_sparse_retrieval(
            cfg, tokenizer.tokenize, datasets
        )
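    # Note: run_sparse_retrieval (utils/util.py) is assumed to return a
    # DatasetDict whose "validation" split pairs each question with the
    # retrieved passages as its context; the exact retriever is defined there.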
    eval_dataset = MRC_Dataset(datasets["validation"], tokenizer=tokenizer)

    # Pad dynamically per batch; pad to a multiple of 8 under fp16 for
    # tensor-core efficiency.
    data_collator = DataCollatorWithPadding(
        tokenizer, pad_to_multiple_of=8 if cfg.train.fp16 else None
    )
print("init trainer...")
# Trainer 초기화
trainer = QuestionAnsweringTrainer(
model=model,
args=inference_args.inference_args,
train_dataset=None,
eval_dataset=eval_dataset,
eval_examples=datasets["validation"],
tokenizer=tokenizer,
data_collator=data_collator,
post_process_function=eval_dataset.post_processing_function,
compute_metrics=compute_metrics,
)
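    # Note: post_processing_function is expected to convert the model's
    # start/end logits into answer text spans from the original passages
    # before compute_metrics is applied (the standard HF QA trainer pattern).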
logger.info("*** Evaluate ***")
#### eval dataset & eval example - predictions.json 생성됨
if cfg.test.test_mode:
predictions = trainer.predict(
test_dataset=eval_dataset, test_examples=datasets["validation"]
)
# predictions.json 은 postprocess_qa_predictions() 호출시 이미 저장됩니다.
print(
"No metric can be presented because there is no correct answer given. Job done!"
)
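

if __name__ == "__main__":
    # Assumed entry point: the snippet ends without a main guard, so this
    # minimal invocation is added to make the script directly runnable.
    test()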