-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #139 from kuke/ctc_decoder_deploy
Add optimized decoders for DS2
- Loading branch information
Showing
34 changed files
with
1,555 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
#!/usr/bin/env bash | ||
set -e | ||
|
||
readonly VERSION="3.8" | ||
readonly VERSION="3.9" | ||
|
||
version=$(clang-format -version) | ||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
204 changes: 204 additions & 0 deletions
204
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
#include "ctc_beam_search_decoder.h" | ||
|
||
#include <algorithm> | ||
#include <cmath> | ||
#include <iostream> | ||
#include <limits> | ||
#include <map> | ||
#include <utility> | ||
|
||
#include "ThreadPool.h" | ||
#include "fst/fstlib.h" | ||
|
||
#include "decoder_utils.h" | ||
#include "path_trie.h" | ||
|
||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; | ||
|
||
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( | ||
const std::vector<std::vector<double>> &probs_seq, | ||
const std::vector<std::string> &vocabulary, | ||
size_t beam_size, | ||
double cutoff_prob, | ||
size_t cutoff_top_n, | ||
Scorer *ext_scorer) { | ||
// dimension check | ||
size_t num_time_steps = probs_seq.size(); | ||
for (size_t i = 0; i < num_time_steps; ++i) { | ||
VALID_CHECK_EQ(probs_seq[i].size(), | ||
vocabulary.size() + 1, | ||
"The shape of probs_seq does not match with " | ||
"the shape of the vocabulary"); | ||
} | ||
|
||
// assign blank id | ||
size_t blank_id = vocabulary.size(); | ||
|
||
// assign space id | ||
auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); | ||
int space_id = it - vocabulary.begin(); | ||
// if no space in vocabulary | ||
if ((size_t)space_id >= vocabulary.size()) { | ||
space_id = -2; | ||
} | ||
|
||
// init prefixes' root | ||
PathTrie root; | ||
root.score = root.log_prob_b_prev = 0.0; | ||
std::vector<PathTrie *> prefixes; | ||
prefixes.push_back(&root); | ||
|
||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { | ||
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary); | ||
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); | ||
root.set_dictionary(dict_ptr); | ||
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); | ||
root.set_matcher(matcher); | ||
} | ||
|
||
// prefix search over time | ||
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { | ||
auto &prob = probs_seq[time_step]; | ||
|
||
float min_cutoff = -NUM_FLT_INF; | ||
bool full_beam = false; | ||
if (ext_scorer != nullptr) { | ||
size_t num_prefixes = std::min(prefixes.size(), beam_size); | ||
std::sort( | ||
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); | ||
min_cutoff = prefixes[num_prefixes - 1]->score + | ||
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta); | ||
full_beam = (num_prefixes == beam_size); | ||
} | ||
|
||
std::vector<std::pair<size_t, float>> log_prob_idx = | ||
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); | ||
// loop over chars | ||
for (size_t index = 0; index < log_prob_idx.size(); index++) { | ||
auto c = log_prob_idx[index].first; | ||
auto log_prob_c = log_prob_idx[index].second; | ||
|
||
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { | ||
auto prefix = prefixes[i]; | ||
if (full_beam && log_prob_c + prefix->score < min_cutoff) { | ||
break; | ||
} | ||
// blank | ||
if (c == blank_id) { | ||
prefix->log_prob_b_cur = | ||
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); | ||
continue; | ||
} | ||
// repeated character | ||
if (c == prefix->character) { | ||
prefix->log_prob_nb_cur = log_sum_exp( | ||
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); | ||
} | ||
// get new prefix | ||
auto prefix_new = prefix->get_path_trie(c); | ||
|
||
if (prefix_new != nullptr) { | ||
float log_p = -NUM_FLT_INF; | ||
|
||
if (c == prefix->character && | ||
prefix->log_prob_b_prev > -NUM_FLT_INF) { | ||
log_p = log_prob_c + prefix->log_prob_b_prev; | ||
} else if (c != prefix->character) { | ||
log_p = log_prob_c + prefix->score; | ||
} | ||
|
||
// language model scoring | ||
if (ext_scorer != nullptr && | ||
(c == space_id || ext_scorer->is_character_based())) { | ||
PathTrie *prefix_toscore = nullptr; | ||
// skip scoring the space | ||
if (ext_scorer->is_character_based()) { | ||
prefix_toscore = prefix_new; | ||
} else { | ||
prefix_toscore = prefix; | ||
} | ||
|
||
double score = 0.0; | ||
std::vector<std::string> ngram; | ||
ngram = ext_scorer->make_ngram(prefix_toscore); | ||
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; | ||
log_p += score; | ||
log_p += ext_scorer->beta; | ||
} | ||
prefix_new->log_prob_nb_cur = | ||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p); | ||
} | ||
} // end of loop over prefix | ||
} // end of loop over vocabulary | ||
|
||
prefixes.clear(); | ||
// update log probs | ||
root.iterate_to_vec(prefixes); | ||
|
||
// only preserve top beam_size prefixes | ||
if (prefixes.size() >= beam_size) { | ||
std::nth_element(prefixes.begin(), | ||
prefixes.begin() + beam_size, | ||
prefixes.end(), | ||
prefix_compare); | ||
for (size_t i = beam_size; i < prefixes.size(); ++i) { | ||
prefixes[i]->remove(); | ||
} | ||
} | ||
} // end of loop over time | ||
|
||
// compute aproximate ctc score as the return score, without affecting the | ||
// return order of decoding result. To delete when decoder gets stable. | ||
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { | ||
double approx_ctc = prefixes[i]->score; | ||
if (ext_scorer != nullptr) { | ||
std::vector<int> output; | ||
prefixes[i]->get_path_vec(output); | ||
auto prefix_length = output.size(); | ||
auto words = ext_scorer->split_labels(output); | ||
// remove word insert | ||
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; | ||
// remove language model weight: | ||
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; | ||
} | ||
prefixes[i]->approx_ctc = approx_ctc; | ||
} | ||
|
||
return get_beam_search_result(prefixes, vocabulary, beam_size); | ||
} | ||
|
||
|
||
std::vector<std::vector<std::pair<double, std::string>>> | ||
ctc_beam_search_decoder_batch( | ||
const std::vector<std::vector<std::vector<double>>> &probs_split, | ||
const std::vector<std::string> &vocabulary, | ||
size_t beam_size, | ||
size_t num_processes, | ||
double cutoff_prob, | ||
size_t cutoff_top_n, | ||
Scorer *ext_scorer) { | ||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); | ||
// thread pool | ||
ThreadPool pool(num_processes); | ||
// number of samples | ||
size_t batch_size = probs_split.size(); | ||
|
||
// enqueue the tasks of decoding | ||
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; | ||
for (size_t i = 0; i < batch_size; ++i) { | ||
res.emplace_back(pool.enqueue(ctc_beam_search_decoder, | ||
probs_split[i], | ||
vocabulary, | ||
beam_size, | ||
cutoff_prob, | ||
cutoff_top_n, | ||
ext_scorer)); | ||
} | ||
|
||
// get decoding results | ||
std::vector<std::vector<std::pair<double, std::string>>> batch_results; | ||
for (size_t i = 0; i < batch_size; ++i) { | ||
batch_results.emplace_back(res[i].get()); | ||
} | ||
return batch_results; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#ifndef CTC_BEAM_SEARCH_DECODER_H_ | ||
#define CTC_BEAM_SEARCH_DECODER_H_ | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "scorer.h" | ||
|
||
/* CTC Beam Search Decoder | ||
* Parameters: | ||
* probs_seq: 2-D vector that each element is a vector of probabilities | ||
* over vocabulary of one time step. | ||
* vocabulary: A vector of vocabulary. | ||
* beam_size: The width of beam search. | ||
* cutoff_prob: Cutoff probability for pruning. | ||
* cutoff_top_n: Cutoff number for pruning. | ||
* ext_scorer: External scorer to evaluate a prefix, which consists of | ||
* n-gram language model scoring and word insertion term. | ||
* Default null, decoding the input sample without scorer. | ||
* Return: | ||
* A vector that each element is a pair of score and decoding result, | ||
* in desending order. | ||
*/ | ||
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( | ||
const std::vector<std::vector<double>> &probs_seq, | ||
const std::vector<std::string> &vocabulary, | ||
size_t beam_size, | ||
double cutoff_prob = 1.0, | ||
size_t cutoff_top_n = 40, | ||
Scorer *ext_scorer = nullptr); | ||
|
||
/* CTC Beam Search Decoder for batch data | ||
* Parameters: | ||
* probs_seq: 3-D vector that each element is a 2-D vector that can be used | ||
* by ctc_beam_search_decoder(). | ||
* vocabulary: A vector of vocabulary. | ||
* beam_size: The width of beam search. | ||
* num_processes: Number of threads for beam search. | ||
* cutoff_prob: Cutoff probability for pruning. | ||
* cutoff_top_n: Cutoff number for pruning. | ||
* ext_scorer: External scorer to evaluate a prefix, which consists of | ||
* n-gram language model scoring and word insertion term. | ||
* Default null, decoding the input sample without scorer. | ||
* Return: | ||
* A 2-D vector that each element is a vector of beam search decoding | ||
* result for one audio sample. | ||
*/ | ||
std::vector<std::vector<std::pair<double, std::string>>> | ||
ctc_beam_search_decoder_batch( | ||
const std::vector<std::vector<std::vector<double>>> &probs_split, | ||
const std::vector<std::string> &vocabulary, | ||
size_t beam_size, | ||
size_t num_processes, | ||
double cutoff_prob = 1.0, | ||
size_t cutoff_top_n = 40, | ||
Scorer *ext_scorer = nullptr); | ||
|
||
#endif // CTC_BEAM_SEARCH_DECODER_H_ |
Oops, something went wrong.