[s2t] move s2t data preprocess into paddlespeech.dataset #3189

Merged: 4 commits, Apr 23, 2023
15 changes: 11 additions & 4 deletions examples/aishell/asr1/local/test.sh
@@ -1,15 +1,21 @@
#!/bin/bash

if [ $# != 3 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi
set -e

stage=0
stop_stage=100

source utils/parse_options.sh || exit 1;

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."


if [ $# != 3 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
@@ -92,6 +98,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi

if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
echo "using sclite to compute cer..."
# format the reference test file for sclite
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
3 changes: 2 additions & 1 deletion paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py
@@ -28,6 +28,7 @@

from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
from paddlespeech.utils.argparse import print_arguments

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@@ -139,7 +140,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):


def main():
print(f"args: {args}")
print_arguments(args, globals())
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)

3 changes: 2 additions & 1 deletion paddlespeech/dataset/aishell/aishell.py
@@ -28,6 +28,7 @@

from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
from paddlespeech.utils.argparse import print_arguments

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@@ -205,7 +206,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path=None, check=False):


def main():
print(f"args: {args}")
print_arguments(args, globals())
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)

20 changes: 20 additions & 0 deletions paddlespeech/dataset/s2t/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# s2t utils binaries.
from .avg_model import main as avg_ckpts_main
from .build_vocab import main as build_vocab_main
from .compute_mean_std import main as compute_mean_std_main
from .compute_wer import main as compute_wer_main
from .format_data import main as format_data_main
from .format_rsl import main as format_rsl_main
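
With these re-exports, downstream recipes can reach the s2t preprocessing entry points through a single import path instead of copying per-example utility scripts. A minimal sketch of programmatic use, assuming paddlespeech and its s2t dependencies are installed; the manifest and vocab paths below are hypothetical placeholders:

    import sys
    from paddlespeech.dataset.s2t import build_vocab_main

    # The entry points parse sys.argv, so arguments are passed exactly as on
    # the command line. Paths here are placeholder examples.
    sys.argv = [
        "build_vocab",
        "--unit_type", "char",
        "--vocab_path", "data/lang_char/vocab.txt",
        "--manifest_paths", "data/manifest.train.raw",
    ]
    build_vocab_main()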
125 changes: 125 additions & 0 deletions paddlespeech/dataset/s2t/avg_model.py
@@ -0,0 +1,125 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os

import numpy as np
import paddle


def define_argparse():
parser = argparse.ArgumentParser(description='average model')
parser.add_argument('--dst_model', required=True, help='averaged model')
parser.add_argument(
'--ckpt_dir', required=True, help='ckpt model dir for average')
parser.add_argument(
'--val_best', action="store_true", help='averaged model')
parser.add_argument(
'--num', default=5, type=int, help='nums for averaged model')
parser.add_argument(
'--min_epoch',
default=0,
type=int,
help='min epoch used for averaging model')
parser.add_argument(
'--max_epoch',
default=65536, # Big enough
type=int,
help='max epoch used for averaging model')

args = parser.parse_args()
return args


def average_checkpoints(dst_model="",
ckpt_dir="",
val_best=True,
num=5,
min_epoch=0,
max_epoch=65536):
paddle.set_device('cpu')

val_scores = []
jsons = glob.glob(f'{ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
for y in jsons:
with open(y, 'r') as f:
dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch']
if epoch >= min_epoch and epoch <= max_epoch:
val_scores.append((epoch, loss))
    assert val_scores, f"Did not find any valid checkpoints: {val_scores}"
val_scores = np.array(val_scores)

if val_best:
sort_idx = np.argsort(val_scores[:, 1])
sorted_val_scores = val_scores[sort_idx]
else:
sorted_val_scores = val_scores

beat_val_scores = sorted_val_scores[:num, 1]
selected_epochs = sorted_val_scores[:num, 0].astype(np.int64)
avg_val_score = np.mean(beat_val_scores)
print("selected val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
print("averaged val score = " + str(avg_val_score))

path_list = [
ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:num, 0]
]
print(path_list)

    avg = None
    assert num == len(path_list)
for path in path_list:
print(f'Processing {path}')
states = paddle.load(path)
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
avg[k] /= num

    paddle.save(avg, dst_model)
    print(f'Saving to {dst_model}')

    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f:
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"val_loss_mean": avg_val_score,
"ckpts": path_list,
"epochs": selected_epochs.tolist(),
"val_losses": beat_val_scores.tolist(),
})
f.write(data + "\n")


def main():
args = define_argparse()
    average_checkpoints(**vars(args))


if __name__ == '__main__':
main()
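
Besides the command-line entry point, average_checkpoints can be called directly once the checkpoints and their val_loss JSON sidecars exist in one directory. A minimal sketch; the exp/conformer/checkpoints paths are hypothetical:

    from paddlespeech.dataset.s2t.avg_model import average_checkpoints

    # Average the 5 checkpoints with the lowest validation loss into a single
    # parameter file. Directory and output paths are placeholders.
    average_checkpoints(
        dst_model="exp/conformer/checkpoints/avg_5.pdparams",
        ckpt_dir="exp/conformer/checkpoints",
        val_best=True,
        num=5)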
166 changes: 166 additions & 0 deletions paddlespeech/dataset/s2t/build_vocab.py
@@ -0,0 +1,166 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build vocabulary from manifest files.
Each item in vocabulary file is a character.
"""
import argparse
import functools
import os
import tempfile
from collections import Counter

import jsonlines

from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS
from paddlespeech.s2t.frontend.utility import SPACE
from paddlespeech.s2t.frontend.utility import UNK
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments


def count_manifest(counter, text_feature, manifest_path):
manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)

for line_json in manifest_jsons:
if isinstance(line_json['text'], str):
tokens = text_feature.tokenize(
line_json['text'], replace_space=False)

counter.update(tokens)
else:
assert isinstance(line_json['text'], list)
for text in line_json['text']:
tokens = text_feature.tokenize(text, replace_space=False)
counter.update(tokens)


def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)

for line_json in manifest_jsons:
if isinstance(line_json[key], str):
fileobj.write(line_json[key] + "\n")
else:
assert isinstance(line_json[key], list)
for line in line_json[key]:
fileobj.write(line + "\n")


def build_vocab(manifest_paths="",
vocab_path="examples/librispeech/data/vocab.txt",
unit_type="char",
count_threshold=0,
text_keys='text',
spm_mode="unigram",
spm_vocab_size=0,
spm_model_prefix="",
spm_character_coverage=0.9995):
fout = open(vocab_path, 'w', encoding='utf-8')
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1

if unit_type == 'spm':
# tools/spm_train --input=$wave_data/lang_char/input.txt
# --vocab_size=${nbpe} --model_type=${bpemode}
# --model_prefix=${bpemodel} --input_sentence_size=100000000
import sentencepiece as spm

fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
for manifest_path in manifest_paths:
_text_keys = [text_keys] if type(
text_keys) is not list else text_keys
for text_key in _text_keys:
dump_text_manifest(fp, manifest_path, key=text_key)
fp.close()
# train
spm.SentencePieceTrainer.Train(
input=fp.name,
vocab_size=spm_vocab_size,
model_type=spm_mode,
model_prefix=spm_model_prefix,
input_sentence_size=100000000,
character_coverage=spm_character_coverage)
os.unlink(fp.name)

# encode
text_feature = TextFeaturizer(unit_type, "", spm_model_prefix)
counter = Counter()

for manifest_path in manifest_paths:
count_manifest(counter, text_feature, manifest_path)

count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
tokens = []
for token, count in count_sorted:
if count < count_threshold:
break
# replace space by `<space>`
token = SPACE if token == ' ' else token
tokens.append(token)

tokens = sorted(tokens)
for token in tokens:
fout.write(token + '\n')

fout.write(SOS + "\n") # <sos/eos>
fout.close()


def define_argparse():
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# yapf: disable
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
    add_arg('count_threshold', int, 0,
            "Truncation threshold for char/word counts. Default 0, no truncation.")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath to write the vocabulary.")
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
"You can provide multiple manifest files.",
nargs='+',
required=True)
    add_arg('text_keys', str,
            'text',
            "Keys of the text in manifest for building vocabulary. "
            "You can provide multiple keys.",
            nargs='+')
# bpe
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
    add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, bpe, char, word. Only needed when `unit_type` is spm.")
    add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix. Only needed when `unit_type` is spm.")
    add_arg('spm_character_coverage', float, 0.9995, "Character coverage to determine the minimum symbols.")
    # yapf: enable

args = parser.parse_args()
return args

def main():
args = define_argparse()
print_arguments(args, globals())
build_vocab(**vars(args))

if __name__ == '__main__':
main()
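
build_vocab is likewise importable as a plain function, with keyword arguments mirroring the CLI flags above. A minimal sketch for a character-level vocabulary; the manifest and output paths are placeholders:

    from paddlespeech.dataset.s2t.build_vocab import build_vocab

    # Build a character vocabulary from one or more manifest files.
    build_vocab(
        manifest_paths=["data/manifest.train.raw"],
        vocab_path="data/lang_char/vocab.txt",
        unit_type="char",
        count_threshold=0)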