Commit 1ee79e3

Support Chinese vits models (#368)

csukuangfj authored Oct 18, 2023
1 parent 9efe697 commit 1ee79e3
Showing 16 changed files with 326 additions and 62 deletions.
44 changes: 41 additions & 3 deletions .github/scripts/test-python.sh
@@ -9,6 +9,10 @@ log() {
}

log "Offline TTS test"
# test waves are saved in ./tts
mkdir ./tts

log "vits-ljs test"

wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
@@ -18,14 +22,48 @@ python3 ./python-api-examples/offline-tts.py \
--vits-model=./vits-ljs.onnx \
--vits-lexicon=./lexicon.txt \
--vits-tokens=./tokens.txt \
- --output-filename=./tts.wav \
+ --output-filename=./tts/vits-ljs.wav \
'liliana, the most beautiful and lovely assistant of our team!'

- ls -lh ./tts.wav
- file ./tts.wav
+ ls -lh ./tts

rm -v vits-ljs.onnx ./lexicon.txt ./tokens.txt

log "vits-vctk test"
wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vits-vctk.onnx
wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/lexicon.txt
wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/tokens.txt

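# vits-vctk is a multi-speaker English model; exercise a few speaker IDs via --sid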
for sid in 0 10 90; do
python3 ./python-api-examples/offline-tts.py \
--vits-model=./vits-vctk.onnx \
--vits-lexicon=./lexicon.txt \
--vits-tokens=./tokens.txt \
--sid=$sid \
--output-filename=./tts/vits-vctk-${sid}.wav \
'liliana, the most beautiful and lovely assistant of our team!'
done

rm -v vits-vctk.onnx ./lexicon.txt ./tokens.txt

log "vits-zh-aishell3"

wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/vits-aishell3.onnx
wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/lexicon.txt
wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/tokens.txt

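# vits-zh-aishell3 is a multi-speaker Chinese model; same loop over speaker IDs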
for sid in 0 10 90; do
python3 ./python-api-examples/offline-tts.py \
--vits-model=./vits-aishell3.onnx \
--vits-lexicon=./lexicon.txt \
--vits-tokens=./tokens.txt \
--sid=$sid \
--output-filename=./tts/vits-aishell3-${sid}.wav \
'林美丽最美丽'
done

rm -v vits-aishell3.onnx ./lexicon.txt ./tokens.txt

mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models

2 changes: 1 addition & 1 deletion .github/workflows/run-python-test.yaml
@@ -69,4 +69,4 @@ jobs:
- uses: actions/upload-artifact@v3
with:
name: tts-generated-test-files
- path: tts.wav
+ path: tts
4 changes: 3 additions & 1 deletion CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.8.1")
set(SHERPA_ONNX_VERSION "1.8.2")

# Disable warning about
#
@@ -175,6 +175,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
include(asio)
endif()

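# utfcpp is a header-only UTF-8 library; the new Chinese lexicon code uses it
# to split text into code points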
include(utfcpp)

add_subdirectory(sherpa-onnx)

if(SHERPA_ONNX_ENABLE_C_API)
2 changes: 1 addition & 1 deletion cmake/kaldi-decoder.cmake
@@ -6,7 +6,7 @@ function(download_kaldi_decoder)
set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")

set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
- set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
+ set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

# If you don't have access to the Internet,
16 changes: 8 additions & 8 deletions cmake/kaldi-native-fbank.cmake
@@ -1,9 +1,9 @@
function(download_kaldi_native_fbank)
include(FetchContent)

- set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz")
- set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz")
- set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7")
+ set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz")
+ set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz")
+ set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283")

set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
-   $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz
-   ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz
-   ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz
-   /tmp/kaldi-native-fbank-1.18.1.tar.gz
-   /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz
+   $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz
+   ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz
+   ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz
+   /tmp/kaldi-native-fbank-1.18.5.tar.gz
+   /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz
)

foreach(f IN LISTS possible_file_locations)
45 changes: 45 additions & 0 deletions cmake/utfcpp.cmake
@@ -0,0 +1,45 @@
function(download_utfcpp)
include(FetchContent)

set(utfcpp_URL "https://github.com/nemtrif/utfcpp/archive/refs/tags/v3.2.5.tar.gz")
set(utfcpp_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/utfcpp-3.2.5.tar.gz")
set(utfcpp_HASH "SHA256=14fd1b3c466814cb4c40771b7f207b61d2c7a0aa6a5e620ca05c00df27f25afd")

# If you don't have access to the Internet,
# please pre-download utfcpp
set(possible_file_locations
$ENV{HOME}/Downloads/utfcpp-3.2.5.tar.gz
${PROJECT_SOURCE_DIR}/utfcpp-3.2.5.tar.gz
${PROJECT_BINARY_DIR}/utfcpp-3.2.5.tar.gz
/tmp/utfcpp-3.2.5.tar.gz
/star-fj/fangjun/download/github/utfcpp-3.2.5.tar.gz
)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(utfcpp_URL "${f}")
file(TO_CMAKE_PATH "${utfcpp_URL}" utfcpp_URL)
message(STATUS "Found local downloaded utfcpp: ${utfcpp_URL}")
set(utfcpp_URL2)
break()
endif()
endforeach()

FetchContent_Declare(utfcpp
URL
${utfcpp_URL}
${utfcpp_URL2}
URL_HASH ${utfcpp_HASH}
)

FetchContent_GetProperties(utfcpp)
if(NOT utfcpp_POPULATED)
message(STATUS "Downloading utfcpp from ${utfcpp_URL}")
FetchContent_Populate(utfcpp)
endif()
message(STATUS "utfcpp is downloaded to ${utfcpp_SOURCE_DIR}")
# add_subdirectory(${utfcpp_SOURCE_DIR} ${utfcpp_BINARY_DIR} EXCLUDE_FROM_ALL)
include_directories(${utfcpp_SOURCE_DIR})
endfunction()

download_utfcpp()
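
For context: utfcpp is header-only, so the include_directories() call above is all the build needs. The sketch below is not part of this commit; the include path and helper name are assumptions. It shows how utfcpp can split a UTF-8 string into per-code-point strings, which is presumably what the new SplitUtf8() used in lexicon.cc does for Chinese text:

// split_utf8_sketch.cc — a minimal sketch, assuming utfcpp headers live
// under source/ in the downloaded tree (include path is an assumption).
#include <iostream>
#include <string>
#include <vector>

#include "source/utf8.h"  // utfcpp; utf8::next throws on invalid UTF-8

// Split a UTF-8 string into one std::string per Unicode code point.
static std::vector<std::string> SplitIntoCodePoints(const std::string &s) {
  std::vector<std::string> ans;
  auto it = s.begin();
  while (it != s.end()) {
    auto start = it;
    utf8::next(it, s.end());      // advance past exactly one code point
    ans.emplace_back(start, it);  // keep the raw bytes of that code point
  }
  return ans;
}

int main() {
  for (const auto &w : SplitIntoCodePoints("林美丽最美丽")) {
    std::cout << w << "\n";  // one Chinese character per line
  }
  return 0;
}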
15 changes: 15 additions & 0 deletions python-api-examples/offline-tts.py
@@ -20,9 +20,14 @@
--vits-tokens=./tokens.txt \
--output-filename=./generated.wav \
'liliana, the most beautiful and lovely assistant of our team!'
Please see
https://k2-fsa.github.io/sherpa/onnx/tts/index.html
for details.
"""

import argparse
import time

import sherpa_onnx
import soundfile as sf
@@ -115,7 +120,14 @@ def main():
)
)
tts = sherpa_onnx.OfflineTts(tts_config)

start = time.time()
audio = tts.generate(args.text, sid=args.sid)
end = time.time()
elapsed_seconds = end - start
audio_duration = len(audio.samples) / audio.sample_rate
real_time_factor = elapsed_seconds / audio_duration
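# RTF (real-time factor) below 1 means synthesis runs faster than real time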

sf.write(
args.output_filename,
audio.samples,
@@ -124,6 +136,9 @@ )
)
print(f"Saved to {args.output_filename}")
print(f"The text is '{args.text}'")
print(f"Elapsed seconds: {elapsed_seconds:.3f}")
print(f"Audio duration in seconds: {audio_duration:.3f}")
print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
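For a concrete reading of the new RTF output: if generation takes 0.8 s and produces 2.0 s of audio, RTF = 0.8/2.0 = 0.4.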
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
@@ -331,6 +331,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
stack-test.cc
transpose-test.cc
unbind-test.cc
utfcpp-test.cc
)

function(sherpa_onnx_add_test source)
145 changes: 100 additions & 45 deletions sherpa-onnx/csrc/lexicon.cc
@@ -76,9 +76,105 @@ }
}

Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
- const std::string &punctuations) {
+ const std::string &punctuations, const std::string &language) {
InitLanguage(language);
InitTokens(tokens);
InitLexicon(lexicon);
InitPunctuations(punctuations);
}

std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &text) const {
switch (language_) {
case Language::kEnglish:
return ConvertTextToTokenIdsEnglish(text);
case Language::kChinese:
return ConvertTextToTokenIdsChinese(text);
default:
SHERPA_ONNX_LOGE("Unknonw language: %d", static_cast<int32_t>(language_));
exit(-1);
}

return {};
}

std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
const std::string &text) const {
std::vector<std::string> words = SplitUtf8(text);

std::vector<int64_t> ans;

ans.push_back(token2id_.at("sil"));

for (const auto &w : words) {
if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}

const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
}
ans.push_back(token2id_.at("sil"));
ans.push_back(token2id_.at("eos"));
return ans;
}
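
// Example (hypothetical, not in the commit): for the input "林美丽最美丽",
// SplitUtf8() presumably yields one entry per character, so with every
// character present in the lexicon the returned sequence is
//   [sil, ids(林), ids(美), ids(丽), ids(最), ids(美), ids(丽), sil, eos]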

std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);

std::vector<std::string> words = SplitUtf8(text);

std::vector<int64_t> ans;
for (const auto &w : words) {
if (punctuations_.count(w)) {
ans.push_back(token2id_.at(w));
continue;
}

if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}

const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
if (blank_ != -1) {
ans.push_back(blank_);
}
}

if (blank_ != -1 && !ans.empty()) {
// remove the last blank
ans.resize(ans.size() - 1);
}

return ans;
}

void Lexicon::InitTokens(const std::string &tokens) {
token2id_ = ReadTokens(tokens);
blank_ = token2id_.at(" ");
if (token2id_.count(" ")) {
blank_ = token2id_.at(" ");
}
}

void Lexicon::InitLanguage(const std::string &_lang) {
std::string lang(_lang);
ToLowerCase(&lang);
if (lang == "english") {
language_ = Language::kEnglish;
} else if (lang == "chinese") {
language_ = Language::kChinese;
} else {
SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
exit(-1);
}
}

void Lexicon::InitLexicon(const std::string &lexicon) {
std::ifstream is(lexicon);

std::string word;
@@ -109,55 +205,14 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
}
word2ids_.insert({std::move(word), std::move(ids)});
}
}

// process punctuations
void Lexicon::InitPunctuations(const std::string &punctuations) {
std::vector<std::string> punctuation_list;
SplitStringToVector(punctuations, " ", false, &punctuation_list);
for (auto &s : punctuation_list) {
punctuations_.insert(std::move(s));
}
}

std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);

std::vector<std::string> words;
SplitStringToVector(text, " ", false, &words);

std::vector<int64_t> ans;
for (auto w : words) {
std::vector<int64_t> prefix;
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
// if w begins with a punctuation
prefix.push_back(token2id_.at(std::string(1, w[0])));
w = std::string(w.begin() + 1, w.end());
}

std::vector<int64_t> suffix;
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
suffix.push_back(token2id_.at(std::string(1, w.back())));
w = std::string(w.begin(), w.end() - 1);
}

if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}

const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), prefix.begin(), prefix.end());
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
ans.push_back(blank_);
}

if (!ans.empty()) {
ans.resize(ans.size() - 1);
}

return ans;
}

} // namespace sherpa_onnx
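
Taken together, a hypothetical caller of the new API looks like the sketch below. The header path, file names, and punctuation list are assumptions; the constructor signature and the Chinese sil/eos layout come from the diff above:

// lexicon_usage_sketch.cc — a minimal sketch, not part of this commit.
#include <cstdint>
#include <iostream>
#include <vector>

#include "sherpa-onnx/csrc/lexicon.h"  // header path assumed

int main() {
  // Punctuations are passed as a space-separated list, matching what
  // InitPunctuations() splits on; the exact set here is illustrative.
  sherpa_onnx::Lexicon lexicon("./lexicon.txt", "./tokens.txt",
                               /*punctuations=*/", . ? !",
                               /*language=*/"chinese");

  // For Chinese input the result is: sil + per-word token IDs + sil + eos.
  std::vector<int64_t> ids = lexicon.ConvertTextToTokenIds("林美丽最美丽");
  for (int64_t id : ids) {
    std::cout << id << " ";
  }
  std::cout << "\n";
  return 0;
}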
