# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Train and evaluate a BART seq2seq model on Chinese dialog data.
"""
import argparse
import os
import sys

import pandas as pd
from loguru import logger
from transformers import BartForConditionalGeneration, BertTokenizerFast, Text2TextGenerationPipeline

sys.path.append('../..')
from textgen.seq2seq import BartSeq2SeqModel


def model_fill_mask():
    """Sanity-check the pretrained checkpoint with a fill-mask style prompt."""
    tokenizer = BertTokenizerFast.from_pretrained("fnlp/bart-base-chinese")
    model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")
    text2text_generator = Text2TextGenerationPipeline(model, tokenizer)
    # Prompt: "Beijing is the capital of [MASK]"
    r = text2text_generator("北京是[MASK]的首都", max_length=50, do_sample=False)
    print(r)


def load_data(file_path):
    """Load tab-separated (input_text, target_text) pairs from file_path."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip('\n')
            terms = line.split('\t')
            if len(terms) == 2:
                data.append([terms[0], terms[1]])
            else:
                logger.warning(f'line error: {line}, split size: {len(terms)}')
    return data
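# Expected --train_file layout: one tab-separated pair per line,
#   <input_text>\t<target_text>
# (the placeholder names are illustrative; the real zh_dialog.tsv contents are
# not reproduced here). Lines that do not split into exactly two fields are
# skipped with a warning by load_data() above.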


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', default='../data/zh_dialog.tsv', type=str, help='Training data file')
    parser.add_argument('--model_type', default='bart', type=str, help='Transformers model type')
    parser.add_argument('--model_name', default='fnlp/bart-base-chinese', type=str, help='Transformers model or path')
    parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
    parser.add_argument('--do_predict', action='store_true', help='Whether to run prediction.')
    parser.add_argument('--output_dir', default='./outputs/bart_zh/', type=str, help='Model output directory')
    parser.add_argument('--max_seq_length', default=50, type=int, help='Input max sequence length')
    parser.add_argument('--max_length', default=50, type=int, help='Output max sequence length')
    parser.add_argument('--num_epochs', default=30, type=int, help='Number of training epochs')
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
    args = parser.parse_args()
    logger.info(args)

    if args.do_train:
        logger.info('Loading data...')
        train_data = load_data(args.train_file)
        logger.debug('train_data: {}'.format(train_data[:10]))
        train_df = pd.DataFrame(train_data, columns=["input_text", "target_text"])

        # Evaluation reuses the first 10 training pairs: a smoke-test split,
        # not a held-out set.
        eval_data = load_data(args.train_file)[:10]
        eval_df = pd.DataFrame(eval_data, columns=["input_text", "target_text"])

        model_args = {
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "max_seq_length": args.max_seq_length,
            "max_length": args.max_length,
            "train_batch_size": args.batch_size,
            "num_train_epochs": args.num_epochs,
            "save_eval_checkpoints": False,
            "save_model_every_epoch": False,
            "evaluate_generated_text": True,
            "evaluate_during_training": True,
            "evaluate_during_training_verbose": True,
            "save_best_model": True,
            "output_dir": args.output_dir,
            "best_model_dir": os.path.join(args.output_dir, "best_model"),
            "use_early_stopping": True,
        }
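        # These keys follow simpletransformers-style seq2seq settings (an
        # assumption based on the key names; options not listed here keep
        # that library's defaults).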
        model = BartSeq2SeqModel(
            encoder_type=args.model_type,
            encoder_decoder_type=args.model_type,
            encoder_decoder_name=args.model_name,
            args=model_args
        )

        def sim_text_chars(text1, text2):
            """Character-set overlap ratio between two strings, in [0, 1]."""
            if not text1 or not text2:
                return 0.0
            same = set(text1) & set(text2)
            m = len(same)
            n = max(len(set(text1)), len(set(text2)))
            return m / n
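        # Worked example (hand-checked): sim_text_chars("abc", "abd") -> 2/3,
        # since {a, b} is shared and the larger character set has 3 elements.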

        def count_matches(labels, preds):
            """Average char-overlap similarity across (label, prediction) pairs."""
            logger.debug(f"labels: {labels[:10]}")
            logger.debug(f"preds: {preds[:10]}")
            match = sum([sim_text_chars(label, pred) for label, pred in zip(labels, preds)]) / len(labels)
            logger.debug(f"match: {match}")
            return match

        model.train_model(train_df, eval_data=eval_df, matches=count_matches)
        print(model.eval_model(eval_df, matches=count_matches))

    if args.do_predict:
        # Load the fine-tuned model (and its tokenizer) from output_dir
        tokenizer = BertTokenizerFast.from_pretrained(args.output_dir)
        model = BartSeq2SeqModel(
            encoder_type=args.model_type,
            encoder_decoder_type=args.model_type,
            encoder_decoder_name=args.output_dir,
            tokenizer=tokenizer)
        # Chinese prompts: "What is AI?", "What type of computer are you?",
        # "Do you know thermodynamics?"
        sentences = ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
        print("inputs:", sentences)
        print("outputs:", model.predict(sentences))


if __name__ == '__main__':
    main()
    model_fill_mask()