Skip to content

Commit

Permalink
Add CTC HLG decoding using OpenFst (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 8, 2023
1 parent c12286f commit 4076024
Show file tree
Hide file tree
Showing 39 changed files with 964 additions and 56 deletions.
45 changes: 45 additions & 0 deletions .github/scripts/test-offline-ctc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,48 @@ time $EXE \
$repo/test_wavs/8k.wav

rm -rf $repo

log "------------------------------------------------------------"
log "Run Librispeech zipformer CTC H/HL/HLG decoding (English) "
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"

GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
git lfs pull --include "*.fst"
ls -lh
popd

graphs=(
$repo/H.fst
$repo/HL.fst
$repo/HLG.fst
)

for graph in ${graphs[@]}; do
log "test float32 models with $graph"
time $EXE \
--model-type=zipformer2_ctc \
--ctc.graph=$graph \
--zipformer-ctc-model=$repo/model.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav

log "test int8 models with $graph"
time $EXE \
--model-type=zipformer2_ctc \
--ctc.graph=$graph \
--zipformer-ctc-model=$repo/model.int8.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav
done

rm -rf $repo
2 changes: 1 addition & 1 deletion .github/workflows/test-python-online-websocket-server.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ permissions:
jobs:
python_online_websocket_server:
runs-on: ${{ matrix.os }}
name: ${{ matrix.os }} ${{ matrix.python-version }}
name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }}
strategy:
fail-fast: false
matrix:
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
endif()

include(kaldi-native-fbank)
include(kaldi-decoder)
include(onnxruntime)

if(SHERPA_ONNX_ENABLE_PORTAUDIO)
Expand Down
48 changes: 48 additions & 0 deletions cmake/eigen.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
function(download_eigen)
include(FetchContent)

set(eigen_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz")
set(eigen_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/eigen-3.4.0.tar.gz")
set(eigen_HASH "SHA256=8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72")

# If you don't have access to the Internet,
# please pre-download eigen
set(possible_file_locations
$ENV{HOME}/Downloads/eigen-3.4.0.tar.gz
${PROJECT_SOURCE_DIR}/eigen-3.4.0.tar.gz
${PROJECT_BINARY_DIR}/eigen-3.4.0.tar.gz
/tmp/eigen-3.4.0.tar.gz
/star-fj/fangjun/download/github/eigen-3.4.0.tar.gz
)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(eigen_URL "${f}")
file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL)
message(STATUS "Found local downloaded eigen: ${eigen_URL}")
set(eigen_URL2)
break()
endif()
endforeach()

set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(EIGEN_BUILD_DOC OFF CACHE BOOL "" FORCE)

FetchContent_Declare(eigen
URL ${eigen_URL}
URL_HASH ${eigen_HASH}
)

FetchContent_GetProperties(eigen)
if(NOT eigen_POPULATED)
message(STATUS "Downloading eigen from ${eigen_URL}")
FetchContent_Populate(eigen)
endif()
message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}")
message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}")

add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL)
endfunction()

download_eigen()

78 changes: 78 additions & 0 deletions cmake/kaldi-decoder.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
function(download_kaldi_decoder)
include(FetchContent)

set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.3.tar.gz")
set(kaldi_decoder_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-decoder-0.2.3.tar.gz")
set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")

set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

# If you don't have access to the Internet,
# please pre-download kaldi-decoder
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-decoder-0.2.3.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-decoder-0.2.3.tar.gz
${PROJECT_BINARY_DIR}/kaldi-decoder-0.2.3.tar.gz
/tmp/kaldi-decoder-0.2.3.tar.gz
/star-fj/fangjun/download/github/kaldi-decoder-0.2.3.tar.gz
)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(kaldi_decoder_URL "${f}")
file(TO_CMAKE_PATH "${kaldi_decoder_URL}" kaldi_decoder_URL)
message(STATUS "Found local downloaded kaldi-decoder: ${kaldi_decoder_URL}")
set(kaldi_decoder_URL2 )
break()
endif()
endforeach()

FetchContent_Declare(kaldi_decoder
URL
${kaldi_decoder_URL}
${kaldi_decoder_URL2}
URL_HASH ${kaldi_decoder_HASH}
)

FetchContent_GetProperties(kaldi_decoder)
if(NOT kaldi_decoder_POPULATED)
message(STATUS "Downloading kaldi-decoder from ${kaldi_decoder_URL}")
FetchContent_Populate(kaldi_decoder)
endif()
message(STATUS "kaldi-decoder is downloaded to ${kaldi_decoder_SOURCE_DIR}")
message(STATUS "kaldi-decoder's binary dir is ${kaldi_decoder_BINARY_DIR}")

include_directories(${kaldi_decoder_SOURCE_DIR})
add_subdirectory(${kaldi_decoder_SOURCE_DIR} ${kaldi_decoder_BINARY_DIR} EXCLUDE_FROM_ALL)

target_include_directories(kaldi-decoder-core
INTERFACE
${kaldi-decoder_SOURCE_DIR}/
)
if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
install(TARGETS
kaldi-decoder-core
kaldifst_core
fst
DESTINATION ..)
else()
install(TARGETS
kaldi-decoder-core
kaldifst_core
fst
DESTINATION lib)
endif()

if(WIN32 AND BUILD_SHARED_LIBS)
install(TARGETS
kaldi-decoder-core
kaldifst_core
fst
DESTINATION bin)
endif()
endfunction()

download_kaldi_decoder()

62 changes: 62 additions & 0 deletions cmake/kaldifst.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
function(download_kaldifst)
include(FetchContent)

set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz")
set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz")
set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2")

# If you don't have access to the Internet,
# please pre-download kaldifst
set(possible_file_locations
$ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz
${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz
${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz
/tmp/kaldifst-1.7.6.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz
)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(kaldifst_URL "${f}")
file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL)
message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}")
set(kaldifst_URL2)
break()
endif()
endforeach()

set(KALDIFST_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

FetchContent_Declare(kaldifst
URL ${kaldifst_URL}
URL_HASH ${kaldifst_HASH}
)

FetchContent_GetProperties(kaldifst)
if(NOT kaldifst_POPULATED)
message(STATUS "Downloading kaldifst from ${kaldifst_URL}")
FetchContent_Populate(kaldifst)
endif()
message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}")
message(STATUS "kaldifst's binary dir is ${kaldifst_BINARY_DIR}")

list(APPEND CMAKE_MODULE_PATH ${kaldifst_SOURCE_DIR}/cmake)

add_subdirectory(${kaldifst_SOURCE_DIR} ${kaldifst_BINARY_DIR} EXCLUDE_FROM_ALL)

target_include_directories(kaldifst_core
PUBLIC
${kaldifst_SOURCE_DIR}/
)

target_include_directories(fst
PUBLIC
${openfst_SOURCE_DIR}/src/include
)

set_target_properties(kaldifst_core PROPERTIES OUTPUT_NAME "sherpa-onnx-kaldifst-core")
set_target_properties(fst PROPERTIES OUTPUT_NAME "sherpa-onnx-fst")
endfunction()

download_kaldifst()
2 changes: 1 addition & 1 deletion cmake/sherpa-onnx.pc.in
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ Cflags: -I"${includedir}"
# Note: -lcargs is required only for the following file
# https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c
# We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c
Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fst -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
sherpa-onnx-portaudio_static.lib;
sherpa-onnx-c-api.lib;
sherpa-onnx-core.lib;
kaldi-decoder-core.lib;
sherpa-onnx-kaldifst-core.lib;
sherpa-onnx-fst.lib;
kaldi-native-fbank-core.lib;
absl_base.lib;
absl_city.lib;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
sherpa-onnx-portaudio_static.lib;
sherpa-onnx-c-api.lib;
sherpa-onnx-core.lib;
kaldi-decoder-core.lib;
sherpa-onnx-kaldifst-core.lib;
sherpa-onnx-fst.lib;
kaldi-native-fbank-core.lib;
absl_base.lib;
absl_city.lib;
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ set(sources
features.cc
file-utils.cc
hypothesis.cc
offline-ctc-fst-decoder-config.cc
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
offline-ctc-model.cc
offline-lm-config.cc
Expand All @@ -42,6 +44,8 @@ set(sources
offline-whisper-greedy-search-decoder.cc
offline-whisper-model-config.cc
offline-whisper-model.cc
offline-zipformer-ctc-model-config.cc
offline-zipformer-ctc-model.cc
online-conformer-transducer-model.cc
online-lm-config.cc
online-lm.cc
Expand Down Expand Up @@ -97,6 +101,8 @@ endif()

target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core)

target_link_libraries(sherpa-onnx-core kaldi-decoder-core)

if(BUILD_SHARED_LIBS OR APPLE OR CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL arm)
target_link_libraries(sherpa-onnx-core onnxruntime)
else()
Expand Down
32 changes: 32 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"

#include <sstream>
#include <string>

namespace sherpa_onnx {

std::string OfflineCtcFstDecoderConfig::ToString() const {
std::ostringstream os;

os << "OfflineCtcFstDecoderConfig(";
os << "graph=\"" << graph << "\", ";
os << "max_active=" << max_active << ")";

return os.str();
}

void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
std::string prefix = "ctc";
ParseOptions p(prefix, po);

p.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");

p.Register("max-active", &max_active,
"Decoder max active states. Larger->slower; more accurate");
}

} // namespace sherpa_onnx
31 changes: 31 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

struct OfflineCtcFstDecoderConfig {
// Path to H.fst, HL.fst or HLG.fst
std::string graph;
int32_t max_active = 3000;

OfflineCtcFstDecoderConfig() = default;

OfflineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
: graph(graph), max_active(max_active) {}

std::string ToString() const;

void Register(ParseOptions *po);
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
Loading

0 comments on commit 4076024

Please sign in to comment.