forked from fishaudio/Bert-VITS2
-
Notifications
You must be signed in to change notification settings - Fork 107
/
Copy pathpreprocess_text.py
264 lines (225 loc) · 9.32 KB
/
preprocess_text.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import argparse
import json
from collections import defaultdict
from pathlib import Path
from random import sample, shuffle
from typing import Optional
from tqdm import tqdm
from config import get_config
from style_bert_vits2.logging import logger
from style_bert_vits2.nlp import clean_text
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker
from style_bert_vits2.nlp.japanese.user_dict import update_dict
from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT
# Start the pyopenjtalk worker from this process so the dictionary can be used here
pyopenjtalk_worker.initialize_worker()
# Apply the dictionary data under dict_data/ to pyopenjtalk
update_dict()
preprocess_text_config = get_config().preprocess_text_config
def count_lines(file_path: Path):
    """Return the number of lines in *file_path* (used as the tqdm total)."""
    total = 0
    with file_path.open("r", encoding="utf-8") as fp:
        for _ in fp:
            total += 1
    return total
def write_error_log(error_log_path: Path, line: str, error: Exception):
    """Append the offending input line and its exception to the error log."""
    entry = f"{line.strip()}\n{error}\n\n"
    with error_log_path.open("a", encoding="utf-8") as log_file:
        log_file.write(entry)
def process_line(
    line: str,
    transcription_path: Path,
    correct_path: bool,
    use_jp_extra: bool,
    yomi_error: str,
):
    """Clean one ``utt|spk|language|text`` line into the 7-field training format.

    Raises ValueError when the line does not have exactly 4 fields; cleaning
    errors propagate from ``clean_text`` (controlled by ``yomi_error``).
    """
    fields = line.strip().split("|")
    if len(fields) != 4:
        raise ValueError(f"Invalid line format: {line.strip()}")
    utt, spk, language, text = fields

    norm_text, phones, tones, word2ph = clean_text(
        text=text,
        language=language,  # type: ignore
        use_jp_extra=use_jp_extra,
        raise_yomi_error=(yomi_error != "use"),
    )

    # Optionally rewrite the utterance path relative to the dataset's wavs/ dir.
    if correct_path:
        utt = str(transcription_path.parent / "wavs" / utt)

    phone_str = " ".join(phones)
    tone_str = " ".join(str(t) for t in tones)
    word2ph_str = " ".join(str(w) for w in word2ph)
    return f"{utt}|{spk}|{language}|{norm_text}|{phone_str}|{tone_str}|{word2ph_str}\n"
def preprocess(
    transcription_path: Path,
    cleaned_path: Optional[Path],
    train_path: Path,
    val_path: Path,
    config_path: Path,
    val_per_lang: int,
    max_val_total: int,
    use_jp_extra: bool,
    yomi_error: str,
    correct_path: bool,
):
    """Clean a transcription list and split it into train/val filelists.

    Reads ``transcription_path`` (one ``utt|spk|language|text`` per line),
    cleans each line, writes the result to ``cleaned_path``, splits it per
    speaker into ``train_path``/``val_path``, and records the speaker-to-ID
    map in ``config_path``.

    Args:
        transcription_path: Input filelist (``utt|spk|language|text`` per line).
        cleaned_path: Output path for the cleaned list; defaults to
            ``<transcription_path>.cleaned`` when None (or a legacy "").
        train_path: Output path for the training filelist.
        val_path: Output path for the validation filelist.
        config_path: JSON config updated with ``spk2id`` and ``n_speakers``.
        val_per_lang: Number of validation utterances per SPEAKER (the name
            is kept for compatibility with the original code).
        max_val_total: Cap on the total validation-set size; the overflow is
            moved back into the training list.
        use_jp_extra: Whether to clean text with the JP-Extra pipeline.
        yomi_error: One of "raise", "skip", "use" — how reading (yomi)
            errors during cleaning are treated.
        correct_path: If True, rewrite each utterance path to
            ``<transcription dir>/wavs/<utt>``.

    Raises:
        ValueError: If ``yomi_error`` is not one of the accepted values.
        Exception: If cleaning errors occurred and ``yomi_error`` != "skip".
    """
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if yomi_error not in ("raise", "skip", "use"):
        raise ValueError(f"Invalid yomi_error: {yomi_error}")
    # Truthiness check: the old `cleaned_path == ""` compared Path to str and
    # was always False; `not cleaned_path` covers both None and a legacy "".
    if not cleaned_path:
        cleaned_path = transcription_path.with_name(
            transcription_path.name + ".cleaned"
        )

    error_log_path = transcription_path.parent / "text_error.log"
    if error_log_path.exists():
        error_log_path.unlink()
    error_count = 0

    total_lines = count_lines(transcription_path)

    # Read transcription_path line by line, clean each line, and write the
    # result to cleaned_path.
    with (
        transcription_path.open("r", encoding="utf-8") as trans_file,
        cleaned_path.open("w", encoding="utf-8") as out_file,
    ):
        for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines):
            try:
                processed_line = process_line(
                    line,
                    transcription_path,
                    correct_path,
                    use_jp_extra,
                    yomi_error,
                )
                out_file.write(processed_line)
            except Exception as e:
                # BUG FIX: dropped the stray `encoding="utf-8"` kwarg —
                # logger.error() has no such parameter.
                logger.error(f"An error occurred at line:\n{line.strip()}\n{e}")
                write_error_log(error_log_path, line, e)
                error_count += 1

    transcription_path = cleaned_path

    # Utterance lines grouped by speaker.
    spk_utt_map: dict[str, list[str]] = defaultdict(list)
    # Speaker name -> integer speaker ID.
    spk_id_map: dict[str, int] = {}
    # Next speaker ID to assign.
    current_sid: int = 0

    # Check audio files (existence / duplicates) and build spk_id_map.
    with transcription_path.open("r", encoding="utf-8") as f:
        audio_paths: set[str] = set()
        count_same = 0
        count_not_found = 0
        for line in f:
            utt, spk = line.strip().split("|")[:2]
            if utt in audio_paths:
                logger.warning(f"Same audio file appears multiple times: {utt}")
                count_same += 1
                continue
            if not Path(utt).is_file():
                logger.warning(f"Audio not found: {utt}")
                count_not_found += 1
                continue
            audio_paths.add(utt)
            spk_utt_map[spk].append(line)
            # Assign a new ID the first time each speaker appears.
            if spk not in spk_id_map:
                spk_id_map[spk] = current_sid
                current_sid += 1
    if count_same > 0 or count_not_found > 0:
        logger.warning(
            f"Total repeated audios: {count_same}, Total number of audio not found: {count_not_found}"
        )

    train_list: list[str] = []
    val_list: list[str] = []

    # Split each speaker's utterances into train/val.
    for spk, utts in spk_utt_map.items():
        if val_per_lang == 0:
            train_list.extend(utts)
            continue
        # ROBUSTNESS: cap the sample size so a speaker with fewer than
        # val_per_lang utterances doesn't make random.sample() raise.
        val_indices = set(sample(range(len(utts)), min(val_per_lang, len(utts))))
        # Split while keeping the original line order.
        for index, utt in enumerate(utts):
            if index in val_indices:
                val_list.append(utt)
            else:
                train_list.append(utt)

    # Cap the validation set; surplus goes back to training (order preserved).
    if len(val_list) > max_val_total:
        extra_val = val_list[max_val_total:]
        val_list = val_list[:max_val_total]
        train_list.extend(extra_val)

    with train_path.open("w", encoding="utf-8") as f:
        f.writelines(train_list)

    with val_path.open("w", encoding="utf-8") as f:
        f.writelines(val_list)

    with config_path.open("r", encoding="utf-8") as f:
        json_config = json.load(f)
    json_config["data"]["spk2id"] = spk_id_map
    json_config["data"]["n_speakers"] = len(spk_id_map)
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(json_config, f, indent=2, ensure_ascii=False)

    if error_count > 0:
        if yomi_error == "skip":
            logger.warning(
                f"An error occurred in {error_count} lines. Proceed with lines without errors. Please check {error_log_path} for details."
            )
        else:
            # yomi_error is "raise" or "use". In the "use" case cleaning ran
            # with raise_yomi_error=False, so reaching here means some other
            # exception occurred — re-raise in both cases.
            logger.error(
                f"An error occurred in {error_count} lines. Please check {error_log_path} for details."
            )
            # BUG FIX: "you_model_name" -> "your_model_name". The concrete
            # error_log_path is deliberately kept out of the exception text
            # because interpolating it caused encoding errors on some consoles.
            raise Exception(
                f"An error occurred in {error_count} lines. Please check `Data/your_model_name/text_error.log` file for details."
            )
    else:
        logger.info(
            "Training set and validation set generation from texts is complete!"
        )
if __name__ == "__main__":
    cfg = preprocess_text_config

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transcription-path", default=cfg.transcription_path
    )
    parser.add_argument("--cleaned-path", default=cfg.cleaned_path)
    parser.add_argument("--train-path", default=cfg.train_path)
    parser.add_argument("--val-path", default=cfg.val_path)
    parser.add_argument("--config-path", default=cfg.config_path)
    # Validation count is per SPEAKER, not per language; the historical
    # flag name "val_per_lang" is kept for config compatibility.
    parser.add_argument(
        "--val-per-lang",
        default=cfg.val_per_lang,
        help="Number of validation data per SPEAKER, not per language (due to compatibility with the original code).",
    )
    parser.add_argument("--max-val-total", default=cfg.max_val_total)
    parser.add_argument("--use_jp_extra", action="store_true")
    parser.add_argument("--yomi_error", default="raise")
    parser.add_argument("--correct_path", action="store_true")
    args = parser.parse_args()

    preprocess(
        transcription_path=Path(args.transcription_path),
        cleaned_path=Path(args.cleaned_path) if args.cleaned_path else None,
        train_path=Path(args.train_path),
        val_path=Path(args.val_path),
        config_path=Path(args.config_path),
        val_per_lang=int(args.val_per_lang),
        max_val_total=int(args.max_val_total),
        use_jp_extra=args.use_jp_extra,
        yomi_error=args.yomi_error,
        correct_path=args.correct_path,
    )