diff --git a/.clang_format.hook b/.clang_format.hook index 40d70f56cf..4cbc972bbd 100755 --- a/.clang_format.hook +++ b/.clang_format.hook @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -e -readonly VERSION="3.8" +readonly VERSION="3.9" version=$(clang-format -version) diff --git a/deep_speech_2/decoders/__init__.py b/deep_speech_2/decoders/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/deep_speech_2/model_utils/decoder.py b/deep_speech_2/decoders/decoders_deprecated.py similarity index 95% rename from deep_speech_2/model_utils/decoder.py rename to deep_speech_2/decoders/decoders_deprecated.py index ffba2731a0..17b28b0d02 100644 --- a/deep_speech_2/model_utils/decoder.py +++ b/deep_speech_2/decoders/decoders_deprecated.py @@ -42,8 +42,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary): def ctc_beam_search_decoder(probs_seq, beam_size, vocabulary, - blank_id, cutoff_prob=1.0, + cutoff_top_n=40, ext_scoring_func=None, nproc=False): """CTC Beam search decoder. @@ -66,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq, :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list - :param blank_id: ID of blank. - :type blank_id: int :param cutoff_prob: Cutoff probability in pruning, default 1.0, no pruning. :type cutoff_prob: float @@ -87,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq, raise ValueError("The shape of prob_seq does not match with the " "shape of the vocabulary.") - # blank_id check - if not blank_id < len(probs_seq[0]): - raise ValueError("blank_id shouldn't be greater than probs dimension") + # blank_id assign + blank_id = len(vocabulary) # If the decoder called in the multiprocesses, then use the global scorer # instantiated in ctc_beam_search_decoder_batch(). @@ -114,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq, prob_idx = list(enumerate(probs_seq[time_step])) cutoff_len = len(prob_idx) #If pruning is enabled - if cutoff_prob < 1.0: + if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len: prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True) cutoff_len, cum_prob = 0, 0.0 for i in xrange(len(prob_idx)): @@ -122,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq, cutoff_len += 1 if cum_prob >= cutoff_prob: break + cutoff_len = min(cutoff_len, cutoff_top_n) prob_idx = prob_idx[0:cutoff_len] for l in prefix_set_prev: @@ -191,9 +189,9 @@ def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder_batch(probs_split, beam_size, vocabulary, - blank_id, num_processes, cutoff_prob=1.0, + cutoff_top_n=40, ext_scoring_func=None): """CTC beam search decoder using multiple processes. @@ -204,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split, :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list - :param blank_id: ID of blank. - :type blank_id: int :param num_processes: Number of parallel processes. :type num_processes: int :param cutoff_prob: Cutoff probability in pruning, @@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split, pool = multiprocessing.Pool(processes=num_processes) results = [] for i, probs_list in enumerate(probs_split): - args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None, - nproc) + args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n, + None, nproc) results.append(pool.apply_async(ctc_beam_search_decoder, args)) pool.close() diff --git a/deep_speech_2/model_utils/lm_scorer.py b/deep_speech_2/decoders/scorer_deprecated.py similarity index 98% rename from deep_speech_2/model_utils/lm_scorer.py rename to deep_speech_2/decoders/scorer_deprecated.py index 463e96d665..c6a661030d 100644 --- a/deep_speech_2/model_utils/lm_scorer.py +++ b/deep_speech_2/decoders/scorer_deprecated.py @@ -8,7 +8,7 @@ import numpy as np -class LmScorer(object): +class Scorer(object): """External scorer to evaluate a prefix or whole sentence in beam search decoding, including the score from n-gram language model and word count. diff --git a/deep_speech_2/decoders/swig/__init__.py b/deep_speech_2/decoders/swig/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/deep_speech_2/deploy/_init_paths.py b/deep_speech_2/decoders/swig/_init_paths.py similarity index 100% rename from deep_speech_2/deploy/_init_paths.py rename to deep_speech_2/decoders/swig/_init_paths.py diff --git a/deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp b/deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp new file mode 100644 index 0000000000..624784b05e --- /dev/null +++ b/deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp @@ -0,0 +1,204 @@ +#include "ctc_beam_search_decoder.h" + +#include +#include +#include +#include +#include +#include + +#include "ThreadPool.h" +#include "fst/fstlib.h" + +#include "decoder_utils.h" +#include "path_trie.h" + +using FSTMATCH = fst::SortedMatcher; + +std::vector> ctc_beam_search_decoder( + const std::vector> &probs_seq, + const std::vector &vocabulary, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer) { + // dimension check + size_t num_time_steps = probs_seq.size(); + for (size_t i = 0; i < num_time_steps; ++i) { + VALID_CHECK_EQ(probs_seq[i].size(), + vocabulary.size() + 1, + "The shape of probs_seq does not match with " + "the shape of the vocabulary"); + } + + // assign blank id + size_t blank_id = vocabulary.size(); + + // assign space id + auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); + int space_id = it - vocabulary.begin(); + // if no space in vocabulary + if ((size_t)space_id >= vocabulary.size()) { + space_id = -2; + } + + // init prefixes' root + PathTrie root; + root.score = root.log_prob_b_prev = 0.0; + std::vector prefixes; + prefixes.push_back(&root); + + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + auto fst_dict = static_cast(ext_scorer->dictionary); + fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); + root.set_dictionary(dict_ptr); + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + root.set_matcher(matcher); + } + + // prefix search over time + for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { + auto &prob = probs_seq[time_step]; + + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (ext_scorer != nullptr) { + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort( + prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + min_cutoff = prefixes[num_prefixes - 1]->score + + std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta); + full_beam = (num_prefixes == beam_size); + } + + std::vector> log_prob_idx = + get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); + // loop over chars + for (size_t index = 0; index < log_prob_idx.size(); index++) { + auto c = log_prob_idx[index].first; + auto log_prob_c = log_prob_idx[index].second; + + for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { + auto prefix = prefixes[i]; + if (full_beam && log_prob_c + prefix->score < min_cutoff) { + break; + } + // blank + if (c == blank_id) { + prefix->log_prob_b_cur = + log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); + continue; + } + // repeated character + if (c == prefix->character) { + prefix->log_prob_nb_cur = log_sum_exp( + prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); + } + // get new prefix + auto prefix_new = prefix->get_path_trie(c); + + if (prefix_new != nullptr) { + float log_p = -NUM_FLT_INF; + + if (c == prefix->character && + prefix->log_prob_b_prev > -NUM_FLT_INF) { + log_p = log_prob_c + prefix->log_prob_b_prev; + } else if (c != prefix->character) { + log_p = log_prob_c + prefix->score; + } + + // language model scoring + if (ext_scorer != nullptr && + (c == space_id || ext_scorer->is_character_based())) { + PathTrie *prefix_toscore = nullptr; + // skip scoring the space + if (ext_scorer->is_character_based()) { + prefix_toscore = prefix_new; + } else { + prefix_toscore = prefix; + } + + double score = 0.0; + std::vector ngram; + ngram = ext_scorer->make_ngram(prefix_toscore); + score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + log_p += score; + log_p += ext_scorer->beta; + } + prefix_new->log_prob_nb_cur = + log_sum_exp(prefix_new->log_prob_nb_cur, log_p); + } + } // end of loop over prefix + } // end of loop over vocabulary + + prefixes.clear(); + // update log probs + root.iterate_to_vec(prefixes); + + // only preserve top beam_size prefixes + if (prefixes.size() >= beam_size) { + std::nth_element(prefixes.begin(), + prefixes.begin() + beam_size, + prefixes.end(), + prefix_compare); + for (size_t i = beam_size; i < prefixes.size(); ++i) { + prefixes[i]->remove(); + } + } + } // end of loop over time + + // compute aproximate ctc score as the return score, without affecting the + // return order of decoding result. To delete when decoder gets stable. + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + double approx_ctc = prefixes[i]->score; + if (ext_scorer != nullptr) { + std::vector output; + prefixes[i]->get_path_vec(output); + auto prefix_length = output.size(); + auto words = ext_scorer->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; + // remove language model weight: + approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; + } + prefixes[i]->approx_ctc = approx_ctc; + } + + return get_beam_search_result(prefixes, vocabulary, beam_size); +} + + +std::vector>> +ctc_beam_search_decoder_batch( + const std::vector>> &probs_split, + const std::vector &vocabulary, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer) { + VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); + // thread pool + ThreadPool pool(num_processes); + // number of samples + size_t batch_size = probs_split.size(); + + // enqueue the tasks of decoding + std::vector>>> res; + for (size_t i = 0; i < batch_size; ++i) { + res.emplace_back(pool.enqueue(ctc_beam_search_decoder, + probs_split[i], + vocabulary, + beam_size, + cutoff_prob, + cutoff_top_n, + ext_scorer)); + } + + // get decoding results + std::vector>> batch_results; + for (size_t i = 0; i < batch_size; ++i) { + batch_results.emplace_back(res[i].get()); + } + return batch_results; +} diff --git a/deep_speech_2/decoders/swig/ctc_beam_search_decoder.h b/deep_speech_2/decoders/swig/ctc_beam_search_decoder.h new file mode 100644 index 0000000000..6fdd15517e --- /dev/null +++ b/deep_speech_2/decoders/swig/ctc_beam_search_decoder.h @@ -0,0 +1,61 @@ +#ifndef CTC_BEAM_SEARCH_DECODER_H_ +#define CTC_BEAM_SEARCH_DECODER_H_ + +#include +#include +#include + +#include "scorer.h" + +/* CTC Beam Search Decoder + + * Parameters: + * probs_seq: 2-D vector that each element is a vector of probabilities + * over vocabulary of one time step. + * vocabulary: A vector of vocabulary. + * beam_size: The width of beam search. + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * Return: + * A vector that each element is a pair of score and decoding result, + * in desending order. +*/ +std::vector> ctc_beam_search_decoder( + const std::vector> &probs_seq, + const std::vector &vocabulary, + size_t beam_size, + double cutoff_prob = 1.0, + size_t cutoff_top_n = 40, + Scorer *ext_scorer = nullptr); + +/* CTC Beam Search Decoder for batch data + + * Parameters: + * probs_seq: 3-D vector that each element is a 2-D vector that can be used + * by ctc_beam_search_decoder(). + * vocabulary: A vector of vocabulary. + * beam_size: The width of beam search. + * num_processes: Number of threads for beam search. + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * Return: + * A 2-D vector that each element is a vector of beam search decoding + * result for one audio sample. +*/ +std::vector>> +ctc_beam_search_decoder_batch( + const std::vector>> &probs_split, + const std::vector &vocabulary, + size_t beam_size, + size_t num_processes, + double cutoff_prob = 1.0, + size_t cutoff_top_n = 40, + Scorer *ext_scorer = nullptr); + +#endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deep_speech_2/decoders/swig/ctc_greedy_decoder.cpp b/deep_speech_2/decoders/swig/ctc_greedy_decoder.cpp new file mode 100644 index 0000000000..03449d7391 --- /dev/null +++ b/deep_speech_2/decoders/swig/ctc_greedy_decoder.cpp @@ -0,0 +1,45 @@ +#include "ctc_greedy_decoder.h" +#include "decoder_utils.h" + +std::string ctc_greedy_decoder( + const std::vector> &probs_seq, + const std::vector &vocabulary) { + // dimension check + size_t num_time_steps = probs_seq.size(); + for (size_t i = 0; i < num_time_steps; ++i) { + VALID_CHECK_EQ(probs_seq[i].size(), + vocabulary.size() + 1, + "The shape of probs_seq does not match with " + "the shape of the vocabulary"); + } + + size_t blank_id = vocabulary.size(); + + std::vector max_idx_vec(num_time_steps, 0); + std::vector idx_vec; + for (size_t i = 0; i < num_time_steps; ++i) { + double max_prob = 0.0; + size_t max_idx = 0; + const std::vector &probs_step = probs_seq[i]; + for (size_t j = 0; j < probs_step.size(); ++j) { + if (max_prob < probs_step[j]) { + max_idx = j; + max_prob = probs_step[j]; + } + } + // id with maximum probability in current time step + max_idx_vec[i] = max_idx; + // deduplicate + if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { + idx_vec.push_back(max_idx_vec[i]); + } + } + + std::string best_path_result; + for (size_t i = 0; i < idx_vec.size(); ++i) { + if (idx_vec[i] != blank_id) { + best_path_result += vocabulary[idx_vec[i]]; + } + } + return best_path_result; +} diff --git a/deep_speech_2/decoders/swig/ctc_greedy_decoder.h b/deep_speech_2/decoders/swig/ctc_greedy_decoder.h new file mode 100644 index 0000000000..5e64f692e5 --- /dev/null +++ b/deep_speech_2/decoders/swig/ctc_greedy_decoder.h @@ -0,0 +1,20 @@ +#ifndef CTC_GREEDY_DECODER_H +#define CTC_GREEDY_DECODER_H + +#include +#include + +/* CTC Greedy (Best Path) Decoder + * + * Parameters: + * probs_seq: 2-D vector that each element is a vector of probabilities + * over vocabulary of one time step. + * vocabulary: A vector of vocabulary. + * Return: + * The decoding result in string + */ +std::string ctc_greedy_decoder( + const std::vector>& probs_seq, + const std::vector& vocabulary); + +#endif // CTC_GREEDY_DECODER_H diff --git a/deep_speech_2/decoders/swig/decoder_utils.cpp b/deep_speech_2/decoders/swig/decoder_utils.cpp new file mode 100644 index 0000000000..70a1592889 --- /dev/null +++ b/deep_speech_2/decoders/swig/decoder_utils.cpp @@ -0,0 +1,176 @@ +#include "decoder_utils.h" + +#include +#include +#include + +std::vector> get_pruned_log_probs( + const std::vector &prob_step, + double cutoff_prob, + size_t cutoff_top_n) { + std::vector> prob_idx; + for (size_t i = 0; i < prob_step.size(); ++i) { + prob_idx.push_back(std::pair(i, prob_step[i])); + } + // pruning of vacobulary + size_t cutoff_len = prob_step.size(); + if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { + std::sort( + prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); + if (cutoff_prob < 1.0) { + double cum_prob = 0.0; + cutoff_len = 0; + for (size_t i = 0; i < prob_idx.size(); ++i) { + cum_prob += prob_idx[i].second; + cutoff_len += 1; + if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break; + } + } + prob_idx = std::vector>( + prob_idx.begin(), prob_idx.begin() + cutoff_len); + } + std::vector> log_prob_idx; + for (size_t i = 0; i < cutoff_len; ++i) { + log_prob_idx.push_back(std::pair( + prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); + } + return log_prob_idx; +} + + +std::vector> get_beam_search_result( + const std::vector &prefixes, + const std::vector &vocabulary, + size_t beam_size) { + // allow for the post processing + std::vector space_prefixes; + if (space_prefixes.empty()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + space_prefixes.push_back(prefixes[i]); + } + } + + std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); + std::vector> output_vecs; + for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { + std::vector output; + space_prefixes[i]->get_path_vec(output); + // convert index to string + std::string output_str; + for (size_t j = 0; j < output.size(); j++) { + output_str += vocabulary[output[j]]; + } + std::pair output_pair(-space_prefixes[i]->approx_ctc, + output_str); + output_vecs.emplace_back(output_pair); + } + + return output_vecs; +} + +size_t get_utf8_str_len(const std::string &str) { + size_t str_len = 0; + for (char c : str) { + str_len += ((c & 0xc0) != 0x80); + } + return str_len; +} + +std::vector split_utf8_str(const std::string &str) { + std::vector result; + std::string out_str; + + for (char c : str) { + if ((c & 0xc0) != 0x80) // new UTF-8 character + { + if (!out_str.empty()) { + result.push_back(out_str); + out_str.clear(); + } + } + + out_str.append(1, c); + } + result.push_back(out_str); + return result; +} + +std::vector split_str(const std::string &s, + const std::string &delim) { + std::vector result; + std::size_t start = 0, delim_len = delim.size(); + while (true) { + std::size_t end = s.find(delim, start); + if (end == std::string::npos) { + if (start < s.size()) { + result.push_back(s.substr(start)); + } + break; + } + if (end > start) { + result.push_back(s.substr(start, end - start)); + } + start = end + delim_len; + } + return result; +} + +bool prefix_compare(const PathTrie *x, const PathTrie *y) { + if (x->score == y->score) { + if (x->character == y->character) { + return false; + } else { + return (x->character < y->character); + } + } else { + return x->score > y->score; + } +} + +void add_word_to_fst(const std::vector &word, + fst::StdVectorFst *dictionary) { + if (dictionary->NumStates() == 0) { + fst::StdVectorFst::StateId start = dictionary->AddState(); + assert(start == 0); + dictionary->SetStart(start); + } + fst::StdVectorFst::StateId src = dictionary->Start(); + fst::StdVectorFst::StateId dst; + for (auto c : word) { + dst = dictionary->AddState(); + dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); + src = dst; + } + dictionary->SetFinal(dst, fst::StdArc::Weight::One()); +} + +bool add_word_to_dictionary( + const std::string &word, + const std::unordered_map &char_map, + bool add_space, + int SPACE_ID, + fst::StdVectorFst *dictionary) { + auto characters = split_utf8_str(word); + + std::vector int_word; + + for (auto &c : characters) { + if (c == " ") { + int_word.push_back(SPACE_ID); + } else { + auto int_c = char_map.find(c); + if (int_c != char_map.end()) { + int_word.push_back(int_c->second); + } else { + return false; // return without adding + } + } + } + + if (add_space) { + int_word.push_back(SPACE_ID); + } + + add_word_to_fst(int_word, dictionary); + return true; // return with successful adding +} diff --git a/deep_speech_2/decoders/swig/decoder_utils.h b/deep_speech_2/decoders/swig/decoder_utils.h new file mode 100644 index 0000000000..72821c187f --- /dev/null +++ b/deep_speech_2/decoders/swig/decoder_utils.h @@ -0,0 +1,94 @@ +#ifndef DECODER_UTILS_H_ +#define DECODER_UTILS_H_ + +#include +#include "fst/log.h" +#include "path_trie.h" + +const float NUM_FLT_INF = std::numeric_limits::max(); +const float NUM_FLT_MIN = std::numeric_limits::min(); + +// inline function for validation check +inline void check( + bool x, const char *expr, const char *file, int line, const char *err) { + if (!x) { + std::cout << "[" << file << ":" << line << "] "; + LOG(FATAL) << "\"" << expr << "\" check failed. " << err; + } +} + +#define VALID_CHECK(x, info) \ + check(static_cast(x), #x, __FILE__, __LINE__, info) +#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info) +#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info) +#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info) + + +// Function template for comparing two pairs +template +bool pair_comp_first_rev(const std::pair &a, + const std::pair &b) { + return a.first > b.first; +} + +// Function template for comparing two pairs +template +bool pair_comp_second_rev(const std::pair &a, + const std::pair &b) { + return a.second > b.second; +} + +// Return the sum of two probabilities in log scale +template +T log_sum_exp(const T &x, const T &y) { + static T num_min = -std::numeric_limits::max(); + if (x <= num_min) return y; + if (y <= num_min) return x; + T xmax = std::max(x, y); + return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; +} + +// Get pruned probability vector for each time step's beam search +std::vector> get_pruned_log_probs( + const std::vector &prob_step, + double cutoff_prob, + size_t cutoff_top_n); + +// Get beam search result from prefixes in trie tree +std::vector> get_beam_search_result( + const std::vector &prefixes, + const std::vector &vocabulary, + size_t beam_size); + +// Functor for prefix comparsion +bool prefix_compare(const PathTrie *x, const PathTrie *y); + +/* Get length of utf8 encoding string + * See: http://stackoverflow.com/a/4063229 + */ +size_t get_utf8_str_len(const std::string &str); + +/* Split a string into a list of strings on a given string + * delimiter. NB: delimiters on beginning / end of string are + * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. + */ +std::vector split_str(const std::string &s, + const std::string &delim); + +/* Splits string into vector of strings representing + * UTF-8 characters (not same as chars) + */ +std::vector split_utf8_str(const std::string &str); + +// Add a word in index to the dicionary of fst +void add_word_to_fst(const std::vector &word, + fst::StdVectorFst *dictionary); + +// Add a word in string to dictionary +bool add_word_to_dictionary( + const std::string &word, + const std::unordered_map &char_map, + bool add_space, + int SPACE_ID, + fst::StdVectorFst *dictionary); +#endif // DECODER_UTILS_H diff --git a/deep_speech_2/decoders/swig/decoders.i b/deep_speech_2/decoders/swig/decoders.i new file mode 100644 index 0000000000..4227d4a375 --- /dev/null +++ b/deep_speech_2/decoders/swig/decoders.i @@ -0,0 +1,33 @@ +%module swig_decoders +%{ +#include "scorer.h" +#include "ctc_greedy_decoder.h" +#include "ctc_beam_search_decoder.h" +#include "decoder_utils.h" +%} + +%include "std_vector.i" +%include "std_pair.i" +%include "std_string.i" +%import "decoder_utils.h" + +namespace std { + %template(DoubleVector) std::vector; + %template(IntVector) std::vector; + %template(StringVector) std::vector; + %template(VectorOfStructVector) std::vector >; + %template(FloatVector) std::vector; + %template(Pair) std::pair; + %template(PairFloatStringVector) std::vector >; + %template(PairDoubleStringVector) std::vector >; + %template(PairDoubleStringVector2) std::vector > >; + %template(DoubleVector3) std::vector > >; +} + +%template(IntDoublePairCompSecondRev) pair_comp_second_rev; +%template(StringDoublePairCompSecondRev) pair_comp_second_rev; +%template(DoubleStringPairCompFirstRev) pair_comp_first_rev; + +%include "scorer.h" +%include "ctc_greedy_decoder.h" +%include "ctc_beam_search_decoder.h" diff --git a/deep_speech_2/decoders/swig/path_trie.cpp b/deep_speech_2/decoders/swig/path_trie.cpp new file mode 100644 index 0000000000..40d9097055 --- /dev/null +++ b/deep_speech_2/decoders/swig/path_trie.cpp @@ -0,0 +1,148 @@ +#include "path_trie.h" + +#include +#include +#include +#include +#include + +#include "decoder_utils.h" + +PathTrie::PathTrie() { + log_prob_b_prev = -NUM_FLT_INF; + log_prob_nb_prev = -NUM_FLT_INF; + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; + score = -NUM_FLT_INF; + + ROOT_ = -1; + character = ROOT_; + exists_ = true; + parent = nullptr; + + dictionary_ = nullptr; + dictionary_state_ = 0; + has_dictionary_ = false; + + matcher_ = nullptr; +} + +PathTrie::~PathTrie() { + for (auto child : children_) { + delete child.second; + } +} + +PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { + auto child = children_.begin(); + for (child = children_.begin(); child != children_.end(); ++child) { + if (child->first == new_char) { + break; + } + } + if (child != children_.end()) { + if (!child->second->exists_) { + child->second->exists_ = true; + child->second->log_prob_b_prev = -NUM_FLT_INF; + child->second->log_prob_nb_prev = -NUM_FLT_INF; + child->second->log_prob_b_cur = -NUM_FLT_INF; + child->second->log_prob_nb_cur = -NUM_FLT_INF; + } + return (child->second); + } else { + if (has_dictionary_) { + matcher_->SetState(dictionary_state_); + bool found = matcher_->Find(new_char); + if (!found) { + // Adding this character causes word outside dictionary + auto FSTZERO = fst::TropicalWeight::Zero(); + auto final_weight = dictionary_->Final(dictionary_state_); + bool is_final = (final_weight != FSTZERO); + if (is_final && reset) { + dictionary_state_ = dictionary_->Start(); + } + return nullptr; + } else { + PathTrie* new_path = new PathTrie; + new_path->character = new_char; + new_path->parent = this; + new_path->dictionary_ = dictionary_; + new_path->dictionary_state_ = matcher_->Value().nextstate; + new_path->has_dictionary_ = true; + new_path->matcher_ = matcher_; + children_.push_back(std::make_pair(new_char, new_path)); + return new_path; + } + } else { + PathTrie* new_path = new PathTrie; + new_path->character = new_char; + new_path->parent = this; + children_.push_back(std::make_pair(new_char, new_path)); + return new_path; + } + } +} + +PathTrie* PathTrie::get_path_vec(std::vector& output) { + return get_path_vec(output, ROOT_); +} + +PathTrie* PathTrie::get_path_vec(std::vector& output, + int stop, + size_t max_steps) { + if (character == stop || character == ROOT_ || output.size() == max_steps) { + std::reverse(output.begin(), output.end()); + return this; + } else { + output.push_back(character); + return parent->get_path_vec(output, stop, max_steps); + } +} + +void PathTrie::iterate_to_vec(std::vector& output) { + if (exists_) { + log_prob_b_prev = log_prob_b_cur; + log_prob_nb_prev = log_prob_nb_cur; + + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; + + score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); + output.push_back(this); + } + for (auto child : children_) { + child.second->iterate_to_vec(output); + } +} + +void PathTrie::remove() { + exists_ = false; + + if (children_.size() == 0) { + auto child = parent->children_.begin(); + for (child = parent->children_.begin(); child != parent->children_.end(); + ++child) { + if (child->first == character) { + parent->children_.erase(child); + break; + } + } + + if (parent->children_.size() == 0 && !parent->exists_) { + parent->remove(); + } + + delete this; + } +} + +void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { + dictionary_ = dictionary; + dictionary_state_ = dictionary->Start(); + has_dictionary_ = true; +} + +using FSTMATCH = fst::SortedMatcher; +void PathTrie::set_matcher(std::shared_ptr matcher) { + matcher_ = matcher; +} diff --git a/deep_speech_2/decoders/swig/path_trie.h b/deep_speech_2/decoders/swig/path_trie.h new file mode 100644 index 0000000000..7fd715d26d --- /dev/null +++ b/deep_speech_2/decoders/swig/path_trie.h @@ -0,0 +1,67 @@ +#ifndef PATH_TRIE_H +#define PATH_TRIE_H + +#include +#include +#include +#include +#include + +#include "fst/fstlib.h" + +/* Trie tree for prefix storing and manipulating, with a dictionary in + * finite-state transducer for spelling correction. + */ +class PathTrie { +public: + PathTrie(); + ~PathTrie(); + + // get new prefix after appending new char + PathTrie* get_path_trie(int new_char, bool reset = true); + + // get the prefix in index from root to current node + PathTrie* get_path_vec(std::vector& output); + + // get the prefix in index from some stop node to current nodel + PathTrie* get_path_vec(std::vector& output, + int stop, + size_t max_steps = std::numeric_limits::max()); + + // update log probs + void iterate_to_vec(std::vector& output); + + // set dictionary for FST + void set_dictionary(fst::StdVectorFst* dictionary); + + void set_matcher(std::shared_ptr>); + + bool is_empty() { return ROOT_ == character; } + + // remove current path from root + void remove(); + + float log_prob_b_prev; + float log_prob_nb_prev; + float log_prob_b_cur; + float log_prob_nb_cur; + float score; + float approx_ctc; + int character; + PathTrie* parent; + +private: + int ROOT_; + bool exists_; + bool has_dictionary_; + + std::vector> children_; + + // pointer to dictionary of FST + fst::StdVectorFst* dictionary_; + fst::StdVectorFst::StateId dictionary_state_; + // true if finding ars in FST + std::shared_ptr> matcher_; +}; + +#endif // PATH_TRIE_H diff --git a/deep_speech_2/decoders/swig/scorer.cpp b/deep_speech_2/decoders/swig/scorer.cpp new file mode 100644 index 0000000000..686c67c77e --- /dev/null +++ b/deep_speech_2/decoders/swig/scorer.cpp @@ -0,0 +1,234 @@ +#include "scorer.h" + +#include +#include + +#include "lm/config.hh" +#include "lm/model.hh" +#include "lm/state.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include "decoder_utils.h" + +using namespace lm::ngram; + +Scorer::Scorer(double alpha, + double beta, + const std::string& lm_path, + const std::vector& vocab_list) { + this->alpha = alpha; + this->beta = beta; + + dictionary = nullptr; + is_character_based_ = true; + language_model_ = nullptr; + + max_order_ = 0; + dict_size_ = 0; + SPACE_ID_ = -1; + + setup(lm_path, vocab_list); +} + +Scorer::~Scorer() { + if (language_model_ != nullptr) { + delete static_cast(language_model_); + } + if (dictionary != nullptr) { + delete static_cast(dictionary); + } +} + +void Scorer::setup(const std::string& lm_path, + const std::vector& vocab_list) { + // load language model + load_lm(lm_path); + // set char map for scorer + set_char_map(vocab_list); + // fill the dictionary for FST + if (!is_character_based()) { + fill_dictionary(true); + } +} + +void Scorer::load_lm(const std::string& lm_path) { + const char* filename = lm_path.c_str(); + VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path"); + + RetriveStrEnumerateVocab enumerate; + lm::ngram::Config config; + config.enumerate_vocab = &enumerate; + language_model_ = lm::ngram::LoadVirtual(filename, config); + max_order_ = static_cast(language_model_)->Order(); + vocabulary_ = enumerate.vocabulary; + for (size_t i = 0; i < vocabulary_.size(); ++i) { + if (is_character_based_ && vocabulary_[i] != UNK_TOKEN && + vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN && + get_utf8_str_len(enumerate.vocabulary[i]) > 1) { + is_character_based_ = false; + } + } +} + +double Scorer::get_log_cond_prob(const std::vector& words) { + lm::base::Model* model = static_cast(language_model_); + double cond_prob; + lm::ngram::State state, tmp_state, out_state; + // avoid to inserting in begin + model->NullContextWrite(&state); + for (size_t i = 0; i < words.size(); ++i) { + lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); + // encounter OOV + if (word_index == 0) { + return OOV_SCORE; + } + cond_prob = model->BaseScore(&state, word_index, &out_state); + tmp_state = state; + state = out_state; + out_state = tmp_state; + } + // return log10 prob + return cond_prob; +} + +double Scorer::get_sent_log_prob(const std::vector& words) { + std::vector sentence; + if (words.size() == 0) { + for (size_t i = 0; i < max_order_; ++i) { + sentence.push_back(START_TOKEN); + } + } else { + for (size_t i = 0; i < max_order_ - 1; ++i) { + sentence.push_back(START_TOKEN); + } + sentence.insert(sentence.end(), words.begin(), words.end()); + } + sentence.push_back(END_TOKEN); + return get_log_prob(sentence); +} + +double Scorer::get_log_prob(const std::vector& words) { + assert(words.size() > max_order_); + double score = 0.0; + for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { + std::vector ngram(words.begin() + i, + words.begin() + i + max_order_); + score += get_log_cond_prob(ngram); + } + return score; +} + +void Scorer::reset_params(float alpha, float beta) { + this->alpha = alpha; + this->beta = beta; +} + +std::string Scorer::vec2str(const std::vector& input) { + std::string word; + for (auto ind : input) { + word += char_list_[ind]; + } + return word; +} + +std::vector Scorer::split_labels(const std::vector& labels) { + if (labels.empty()) return {}; + + std::string s = vec2str(labels); + std::vector words; + if (is_character_based_) { + words = split_utf8_str(s); + } else { + words = split_str(s, " "); + } + return words; +} + +void Scorer::set_char_map(const std::vector& char_list) { + char_list_ = char_list; + char_map_.clear(); + + for (size_t i = 0; i < char_list_.size(); i++) { + if (char_list_[i] == " ") { + SPACE_ID_ = i; + char_map_[' '] = i; + } else if (char_list_[i].size() == 1) { + char_map_[char_list_[i][0]] = i; + } + } +} + +std::vector Scorer::make_ngram(PathTrie* prefix) { + std::vector ngram; + PathTrie* current_node = prefix; + PathTrie* new_node = nullptr; + + for (int order = 0; order < max_order_; order++) { + std::vector prefix_vec; + + if (is_character_based_) { + new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1); + current_node = new_node; + } else { + new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_); + current_node = new_node->parent; // Skipping spaces + } + + // reconstruct word + std::string word = vec2str(prefix_vec); + ngram.push_back(word); + + if (new_node->character == -1) { + // No more spaces, but still need order + for (int i = 0; i < max_order_ - order - 1; i++) { + ngram.push_back(START_TOKEN); + } + break; + } + } + std::reverse(ngram.begin(), ngram.end()); + return ngram; +} + +void Scorer::fill_dictionary(bool add_space) { + fst::StdVectorFst dictionary; + // First reverse char_list so ints can be accessed by chars + std::unordered_map char_map; + for (size_t i = 0; i < char_list_.size(); i++) { + char_map[char_list_[i]] = i; + } + + // For each unigram convert to ints and put in trie + int dict_size = 0; + for (const auto& word : vocabulary_) { + bool added = add_word_to_dictionary( + word, char_map, add_space, SPACE_ID_, &dictionary); + dict_size += added ? 1 : 0; + } + + dict_size_ = dict_size; + + /* Simplify FST + + * This gets rid of "epsilon" transitions in the FST. + * These are transitions that don't require a string input to be taken. + * Getting rid of them is necessary to make the FST determinisitc, but + * can greatly increase the size of the FST + */ + fst::RmEpsilon(&dictionary); + fst::StdVectorFst* new_dict = new fst::StdVectorFst; + + /* This makes the FST deterministic, meaning for any string input there's + * only one possible state the FST could be in. It is assumed our + * dictionary is deterministic when using it. + * (lest we'd have to check for multiple transitions at each state) + */ + fst::Determinize(dictionary, new_dict); + + /* Finds the simplest equivalent fst. This is unnecessary but decreases + * memory usage of the dictionary + */ + fst::Minimize(new_dict); + this->dictionary = new_dict; +} diff --git a/deep_speech_2/decoders/swig/scorer.h b/deep_speech_2/decoders/swig/scorer.h new file mode 100644 index 0000000000..6183646359 --- /dev/null +++ b/deep_speech_2/decoders/swig/scorer.h @@ -0,0 +1,112 @@ +#ifndef SCORER_H_ +#define SCORER_H_ + +#include +#include +#include +#include + +#include "lm/enumerate_vocab.hh" +#include "lm/virtual_interface.hh" +#include "lm/word_index.hh" +#include "util/string_piece.hh" + +#include "path_trie.h" + +const double OOV_SCORE = -1000.0; +const std::string START_TOKEN = ""; +const std::string UNK_TOKEN = ""; +const std::string END_TOKEN = ""; + +// Implement a callback to retrive the dictionary of language model. +class RetriveStrEnumerateVocab : public lm::EnumerateVocab { +public: + RetriveStrEnumerateVocab() {} + + void Add(lm::WordIndex index, const StringPiece &str) { + vocabulary.push_back(std::string(str.data(), str.length())); + } + + std::vector vocabulary; +}; + +/* External scorer to query score for n-gram or sentence, including language + * model scoring and word insertion. + * + * Example: + * Scorer scorer(alpha, beta, "path_of_language_model"); + * scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); + * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); + */ +class Scorer { +public: + Scorer(double alpha, + double beta, + const std::string &lm_path, + const std::vector &vocabulary); + ~Scorer(); + + double get_log_cond_prob(const std::vector &words); + + double get_sent_log_prob(const std::vector &words); + + // return the max order + size_t get_max_order() const { return max_order_; } + + // return the dictionary size of language model + size_t get_dict_size() const { return dict_size_; } + + // retrun true if the language model is character based + bool is_character_based() const { return is_character_based_; } + + // reset params alpha & beta + void reset_params(float alpha, float beta); + + // make ngram for a given prefix + std::vector make_ngram(PathTrie *prefix); + + // trransform the labels in index to the vector of words (word based lm) or + // the vector of characters (character based lm) + std::vector split_labels(const std::vector &labels); + + // language model weight + double alpha; + // word insertion weight + double beta; + + // pointer to the dictionary of FST + void *dictionary; + +protected: + // necessary setup: load language model, set char map, fill FST's dictionary + void setup(const std::string &lm_path, + const std::vector &vocab_list); + + // load language model from given path + void load_lm(const std::string &lm_path); + + // fill dictionary for FST + void fill_dictionary(bool add_space); + + // set char map + void set_char_map(const std::vector &char_list); + + double get_log_prob(const std::vector &words); + + // translate the vector in index to string + std::string vec2str(const std::vector &input); + +private: + void *language_model_; + bool is_character_based_; + size_t max_order_; + size_t dict_size_; + + int SPACE_ID_; + std::vector char_list_; + std::unordered_map char_map_; + + std::vector vocabulary_; +}; + +#endif // SCORER_H_ diff --git a/deep_speech_2/decoders/swig/setup.py b/deep_speech_2/decoders/swig/setup.py new file mode 100644 index 0000000000..8af9ff3047 --- /dev/null +++ b/deep_speech_2/decoders/swig/setup.py @@ -0,0 +1,121 @@ +"""Script to build and install decoder package.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from setuptools import setup, Extension, distutils +import glob +import platform +import os, sys +import multiprocessing.pool +import argparse + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--num_processes", + default=1, + type=int, + help="Number of cpu processes to build package. (default: %(default)d)") +args = parser.parse_known_args() + +# reconstruct sys.argv to pass to setup below +sys.argv = [sys.argv[0]] + args[1] + + +# monkey-patch for parallel compilation +# See: https://stackoverflow.com/a/13176803 +def parallelCCompile(self, + sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None): + # those lines are copied from distutils.ccompiler.CCompiler directly + macros, objects, extra_postargs, pp_opts, build = self._setup_compile( + output_dir, macros, include_dirs, sources, depends, extra_postargs) + cc_args = self._get_cc_args(pp_opts, debug, extra_preargs) + + # parallel code + def _single_compile(obj): + try: + src, ext = build[obj] + except KeyError: + return + self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts) + + # convert to list, imap is evaluated on-demand + thread_pool = multiprocessing.pool.ThreadPool(args[0].num_processes) + list(thread_pool.imap(_single_compile, objects)) + return objects + + +def compile_test(header, library): + dummy_path = os.path.join(os.path.dirname(__file__), "dummy") + command = "bash -c \"g++ -include " + header \ + + " -l" + library + " -x c++ - <<<'int main() {}' -o " \ + + dummy_path + " >/dev/null 2>/dev/null && rm " \ + + dummy_path + " 2>/dev/null\"" + return os.system(command) == 0 + + +# hack compile to support parallel compiling +distutils.ccompiler.CCompiler.compile = parallelCCompile + +FILES = glob.glob('kenlm/util/*.cc') \ + + glob.glob('kenlm/lm/*.cc') \ + + glob.glob('kenlm/util/double-conversion/*.cc') + +FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') + +# FILES + glob.glob('glog/src/*.cc') +FILES = [ + fn for fn in FILES + if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( + 'unittest.cc')) +] + +LIBS = ['stdc++'] +if platform.system() != 'Darwin': + LIBS.append('rt') + +ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11'] + +if compile_test('zlib.h', 'z'): + ARGS.append('-DHAVE_ZLIB') + LIBS.append('z') + +if compile_test('bzlib.h', 'bz2'): + ARGS.append('-DHAVE_BZLIB') + LIBS.append('bz2') + +if compile_test('lzma.h', 'lzma'): + ARGS.append('-DHAVE_XZLIB') + LIBS.append('lzma') + +os.system('swig -python -c++ ./decoders.i') + +decoders_module = [ + Extension( + name='_swig_decoders', + sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'), + language='c++', + include_dirs=[ + '.', + 'kenlm', + 'openfst-1.6.3/src/include', + 'ThreadPool', + #'glog/src' + ], + libraries=LIBS, + extra_compile_args=ARGS) +] + +setup( + name='swig_decoders', + version='0.1', + description="""CTC decoders""", + ext_modules=decoders_module, + py_modules=['swig_decoders'], ) diff --git a/deep_speech_2/decoders/swig/setup.sh b/deep_speech_2/decoders/swig/setup.sh new file mode 100644 index 0000000000..78ae2b2011 --- /dev/null +++ b/deep_speech_2/decoders/swig/setup.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +if [ ! -d kenlm ]; then + git clone https://github.com/luotao1/kenlm.git + echo -e "\n" +fi + +if [ ! -d openfst-1.6.3 ]; then + echo "Download and extract openfst ..." + wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz + tar -xzvf openfst-1.6.3.tar.gz + echo -e "\n" +fi + +if [ ! -d ThreadPool ]; then + git clone https://github.com/progschj/ThreadPool.git + echo -e "\n" +fi + +echo "Install decoders ..." +python setup.py install --num_processes 4 diff --git a/deep_speech_2/decoders/swig_wrapper.py b/deep_speech_2/decoders/swig_wrapper.py new file mode 100644 index 0000000000..0a92112582 --- /dev/null +++ b/deep_speech_2/decoders/swig_wrapper.py @@ -0,0 +1,116 @@ +"""Wrapper for various CTC decoders in SWIG.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import swig_decoders + + +class Scorer(swig_decoders.Scorer): + """Wrapper for Scorer. + + :param alpha: Parameter associated with language model. Don't use + language model when alpha = 0. + :type alpha: float + :param beta: Parameter associated with word count. Don't use word + count when beta = 0. + :type beta: float + :model_path: Path to load language model. + :type model_path: basestring + """ + + def __init__(self, alpha, beta, model_path, vocabulary): + swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) + + +def ctc_greedy_decoder(probs_seq, vocabulary): + """Wrapper for ctc best path decoder in swig. + + :param probs_seq: 2-D list of probability distributions over each time + step, with each element being a list of normalized + probabilities over vocabulary and blank. + :type probs_seq: 2-D list + :param vocabulary: Vocabulary list. + :type vocabulary: list + :return: Decoding result string. + :rtype: basestring + """ + return swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary) + + +def ctc_beam_search_decoder(probs_seq, + vocabulary, + beam_size, + cutoff_prob=1.0, + cutoff_top_n=40, + ext_scoring_func=None): + """Wrapper for the CTC Beam Search Decoder. + + :param probs_seq: 2-D list of probability distributions over each time + step, with each element being a list of normalized + probabilities over vocabulary and blank. + :type probs_seq: 2-D list + :param vocabulary: Vocabulary list. + :type vocabulary: list + :param beam_size: Width for beam search. + :type beam_size: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :type cutoff_prob: float + :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + :type cutoff_top_n: int + :param ext_scoring_func: External scoring function for + partially decoded sentence, e.g. word count + or language model. + :type external_scoring_func: callable + :return: List of tuples of log probability and sentence as decoding + results, in descending order of the probability. + :rtype: list + """ + return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), vocabulary, + beam_size, cutoff_prob, + cutoff_top_n, ext_scoring_func) + + +def ctc_beam_search_decoder_batch(probs_split, + vocabulary, + beam_size, + num_processes, + cutoff_prob=1.0, + cutoff_top_n=40, + ext_scoring_func=None): + """Wrapper for the batched CTC beam search decoder. + + :param probs_seq: 3-D list with each element as an instance of 2-D list + of probabilities used by ctc_beam_search_decoder(). + :type probs_seq: 3-D list + :param vocabulary: Vocabulary list. + :type vocabulary: list + :param beam_size: Width for beam search. + :type beam_size: int + :param num_processes: Number of parallel processes. + :type num_processes: int + :param cutoff_prob: Cutoff probability in vocabulary pruning, + default 1.0, no pruning. + :type cutoff_prob: float + :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + :type cutoff_top_n: int + :param num_processes: Number of parallel processes. + :type num_processes: int + :param ext_scoring_func: External scoring function for + partially decoded sentence, e.g. word count + or language model. + :type external_scoring_function: callable + :return: List of tuples of log probability and sentence as decoding + results, in descending order of the probability. + :rtype: list + """ + probs_split = [probs_seq.tolist() for probs_seq in probs_split] + + return swig_decoders.ctc_beam_search_decoder_batch( + probs_split, vocabulary, beam_size, num_processes, cutoff_prob, + cutoff_top_n, ext_scoring_func) diff --git a/deep_speech_2/model_utils/tests/test_decoders.py b/deep_speech_2/decoders/tests/test_decoders.py similarity index 93% rename from deep_speech_2/model_utils/tests/test_decoders.py rename to deep_speech_2/decoders/tests/test_decoders.py index adf36eefc3..d522b5efa0 100644 --- a/deep_speech_2/model_utils/tests/test_decoders.py +++ b/deep_speech_2/decoders/tests/test_decoders.py @@ -4,7 +4,7 @@ from __future__ import print_function import unittest -from model_utils import decoder +from decoders import decoders_deprecated as decoder class TestDecoders(unittest.TestCase): @@ -66,16 +66,14 @@ def test_beam_search_decoder_1(self): beam_result = decoder.ctc_beam_search_decoder( probs_seq=self.probs_seq1, beam_size=self.beam_size, - vocabulary=self.vocab_list, - blank_id=len(self.vocab_list)) + vocabulary=self.vocab_list) self.assertEqual(beam_result[0][1], self.beam_search_result[0]) def test_beam_search_decoder_2(self): beam_result = decoder.ctc_beam_search_decoder( probs_seq=self.probs_seq2, beam_size=self.beam_size, - vocabulary=self.vocab_list, - blank_id=len(self.vocab_list)) + vocabulary=self.vocab_list) self.assertEqual(beam_result[0][1], self.beam_search_result[1]) def test_beam_search_decoder_batch(self): @@ -83,7 +81,6 @@ def test_beam_search_decoder_batch(self): probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, vocabulary=self.vocab_list, - blank_id=len(self.vocab_list), num_processes=24) self.assertEqual(beam_results[0][0][1], self.beam_search_result[0]) self.assertEqual(beam_results[1][0][1], self.beam_search_result[1]) diff --git a/deep_speech_2/examples/librispeech/run_infer.sh b/deep_speech_2/examples/librispeech/run_infer.sh index eb812440be..85587ed476 100644 --- a/deep_speech_2/examples/librispeech/run_infer.sh +++ b/deep_speech_2/examples/librispeech/run_infer.sh @@ -21,9 +21,10 @@ python -u infer.py \ --num_conv_layers=2 \ --num_rnn_layers=3 \ --rnn_layer_size=2048 \ ---alpha=0.36 \ ---beta=0.25 \ ---cutoff_prob=0.99 \ +--alpha=2.15 \ +--beta=0.35 \ +--cutoff_prob=1.0 \ +--cutoff_top_n=40 \ --use_gru=False \ --use_gpu=True \ --share_rnn_weights=True \ diff --git a/deep_speech_2/examples/librispeech/run_infer_golden.sh b/deep_speech_2/examples/librispeech/run_infer_golden.sh index eeccfdebbc..8feca555ed 100644 --- a/deep_speech_2/examples/librispeech/run_infer_golden.sh +++ b/deep_speech_2/examples/librispeech/run_infer_golden.sh @@ -30,9 +30,10 @@ python -u infer.py \ --num_conv_layers=2 \ --num_rnn_layers=3 \ --rnn_layer_size=2048 \ ---alpha=0.36 \ ---beta=0.25 \ ---cutoff_prob=0.99 \ +--alpha=2.15 \ +--beta=0.35 \ +--cutoff_prob=1.0 \ +--cutoff_top_n=40 \ --use_gru=False \ --use_gpu=True \ --share_rnn_weights=True \ diff --git a/deep_speech_2/examples/librispeech/run_test.sh b/deep_speech_2/examples/librispeech/run_test.sh index 7ef06ba9fd..d75848b00b 100644 --- a/deep_speech_2/examples/librispeech/run_test.sh +++ b/deep_speech_2/examples/librispeech/run_test.sh @@ -22,9 +22,9 @@ python -u test.py \ --num_conv_layers=2 \ --num_rnn_layers=3 \ --rnn_layer_size=2048 \ ---alpha=0.36 \ ---beta=0.25 \ ---cutoff_prob=0.99 \ +--alpha=2.15 \ +--beta=0.35 \ +--cutoff_prob=1.0 \ --use_gru=False \ --use_gpu=True \ --share_rnn_weights=True \ diff --git a/deep_speech_2/examples/librispeech/run_test_golden.sh b/deep_speech_2/examples/librispeech/run_test_golden.sh index 86fe15306a..352a941565 100644 --- a/deep_speech_2/examples/librispeech/run_test_golden.sh +++ b/deep_speech_2/examples/librispeech/run_test_golden.sh @@ -31,9 +31,10 @@ python -u test.py \ --num_conv_layers=2 \ --num_rnn_layers=3 \ --rnn_layer_size=2048 \ ---alpha=0.36 \ ---beta=0.25 \ ---cutoff_prob=0.99 \ +--alpha=2.15 \ +--beta=0.35 \ +--cutoff_prob=1.0 \ +--cutoff_top_n=40 \ --use_gru=False \ --use_gpu=True \ --share_rnn_weights=True \ diff --git a/deep_speech_2/examples/tiny/run_infer.sh b/deep_speech_2/examples/tiny/run_infer.sh index dafc99d9c5..85b083a277 100644 --- a/deep_speech_2/examples/tiny/run_infer.sh +++ b/deep_speech_2/examples/tiny/run_infer.sh @@ -21,9 +21,9 @@ python -u infer.py \ --num_conv_layers=2 \ --num_rnn_layers=3 \ --rnn_layer_size=2048 \ ---alpha=0.36 \ ---beta=0.25 \ ---cutoff_prob=0.99 \ +--alpha=2.15 \ +--beta=0.35 \ +--cutoff_prob=1.0 \ --use_gru=False \ --use_gpu=True \ --share_rnn_weights=True \ diff --git a/deep_speech_2/examples/tiny/run_infer_golden.sh b/deep_speech_2/examples/tiny/run_infer_golden.sh index 66360a6917..3ee2f9aefe 100644 --- a/deep_speech_2/examples/tiny/run_infer_golden.sh +++ b/deep_speech_2/examples/tiny/run_infer_golden.sh @@ -30,9 +30,9 @@ python -u infer.py \ --num_conv_layers=2 \ --num_rnn_layers=3 \ --rnn_layer_size=2048 \ ---alpha=0.36 \ ---beta=0.25 \ ---cutoff_prob=0.99 \ +--alpha=2.15 \ +--beta=0.35 \ +--cutoff_prob=1.0 \ --use_gru=False \ --use_gpu=True \ --share_rnn_weights=True \ diff --git a/deep_speech_2/examples/tiny/run_test.sh b/deep_speech_2/examples/tiny/run_test.sh index 70cf4bfe2e..063076328a 100644 --- a/deep_speech_2/examples/tiny/run_test.sh +++ b/deep_speech_2/examples/tiny/run_test.sh @@ -22,9 +22,9 @@ python -u test.py \ --num_conv_layers=2 \ --num_rnn_layers=3 \ --rnn_layer_size=2048 \ ---alpha=0.36 \ ---beta=0.25 \ ---cutoff_prob=0.99 \ +--alpha=2.15 \ +--beta=0.35 \ +--cutoff_prob=1.0 \ --use_gru=False \ --use_gpu=True \ --share_rnn_weights=True \ diff --git a/deep_speech_2/examples/tiny/run_test_golden.sh b/deep_speech_2/examples/tiny/run_test_golden.sh index e188c81b3f..351cb87cb7 100644 --- a/deep_speech_2/examples/tiny/run_test_golden.sh +++ b/deep_speech_2/examples/tiny/run_test_golden.sh @@ -31,9 +31,9 @@ python -u test.py \ --num_conv_layers=2 \ --num_rnn_layers=3 \ --rnn_layer_size=2048 \ ---alpha=0.36 \ ---beta=0.25 \ ---cutoff_prob=0.99 \ +--alpha=2.15 \ +--beta=0.35 \ +--cutoff_prob=1.0 \ --use_gru=False \ --use_gpu=True \ --share_rnn_weights=True \ diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index d9c4c67763..e635f6d0f9 100644 --- a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -21,9 +21,10 @@ add_arg('num_conv_layers', int, 2, "# of convolution layers.") add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") -add_arg('alpha', float, 0.36, "Coef of LM for beam search.") -add_arg('beta', float, 0.25, "Coef of WC for beam search.") -add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") +add_arg('alpha', float, 2.15, "Coef of LM for beam search.") +add_arg('beta', float, 0.35, "Coef of WC for beam search.") +add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.") +add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") add_arg('use_gpu', bool, True, "Use GPU or not.") add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " @@ -84,6 +85,10 @@ def infer(): use_gru=args.use_gru, pretrained_model_path=args.model_path, share_rnn_weights=args.share_rnn_weights) + + # decoders only accept string encoded in utf-8 + vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + result_transcripts = ds2_model.infer_batch( infer_data=infer_data, decoding_method=args.decoding_method, @@ -91,7 +96,8 @@ def infer(): beam_beta=args.beta, beam_size=args.beam_size, cutoff_prob=args.cutoff_prob, - vocab_list=data_generator.vocab_list, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, language_model_path=args.lang_model_path, num_processes=args.num_proc_bsearch) @@ -106,6 +112,7 @@ def infer(): print("Current error rate [%s] = %f" % (args.error_rate_type, error_rate_func(target, result))) + ds2_model.logger.info("finish inference") def main(): print_arguments(args) diff --git a/deep_speech_2/model_utils/model.py b/deep_speech_2/model_utils/model.py index a7c08ba5e7..67a41bd111 100644 --- a/deep_speech_2/model_utils/model.py +++ b/deep_speech_2/model_utils/model.py @@ -6,14 +6,18 @@ import sys import os import time +import logging import gzip from distutils.dir_util import mkpath import paddle.v2 as paddle -from model_utils.lm_scorer import LmScorer -from model_utils.decoder import ctc_greedy_decoder, ctc_beam_search_decoder -from model_utils.decoder import ctc_beam_search_decoder_batch +from decoders.swig_wrapper import Scorer +from decoders.swig_wrapper import ctc_greedy_decoder +from decoders.swig_wrapper import ctc_beam_search_decoder_batch from model_utils.network import deep_speech_v2_network +logging.basicConfig( + format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') + class DeepSpeech2Model(object): """DeepSpeech2Model class. @@ -44,6 +48,8 @@ def __init__(self, vocab_size, num_conv_layers, num_rnn_layers, self._inferer = None self._loss_inferer = None self._ext_scorer = None + self.logger = logging.getLogger("") + self.logger.setLevel(level=logging.INFO) def train(self, train_batch_reader, @@ -157,8 +163,8 @@ def infer_loss_batch(self, infer_data): return self._loss_inferer.infer(input=infer_data) def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta, - beam_size, cutoff_prob, vocab_list, language_model_path, - num_processes): + beam_size, cutoff_prob, cutoff_top_n, vocab_list, + language_model_path, num_processes): """Model inference. Infer the transcription for a batch of speech utterances. @@ -178,6 +184,10 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta, :param cutoff_prob: Cutoff probability in pruning, default 1.0, no pruning. :type cutoff_prob: float + :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + :type cutoff_top_n: int :param vocab_list: List of tokens in the vocabulary, for decoding. :type vocab_list: list :param language_model_path: Filepath for language model. @@ -209,21 +219,33 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta, elif decoding_method == "ctc_beam_search": # initialize external scorer if self._ext_scorer == None: - self._ext_scorer = LmScorer(beam_alpha, beam_beta, - language_model_path) self._loaded_lm_path = language_model_path + self.logger.info("begin to initialize the external scorer " + "for decoding") + self._ext_scorer = Scorer(beam_alpha, beam_beta, + language_model_path, vocab_list) + + lm_char_based = self._ext_scorer.is_character_based() + lm_max_order = self._ext_scorer.get_max_order() + lm_dict_size = self._ext_scorer.get_dict_size() + self.logger.info("language model: " + "is_character_based = %d," % lm_char_based + + " max_order = %d," % lm_max_order + + " dict_size = %d" % lm_dict_size) + self.logger.info("end initializing scorer. Start decoding ...") else: self._ext_scorer.reset_params(beam_alpha, beam_beta) assert self._loaded_lm_path == language_model_path # beam search decode + num_processes = min(num_processes, len(probs_split)) beam_search_results = ctc_beam_search_decoder_batch( probs_split=probs_split, vocabulary=vocab_list, beam_size=beam_size, - blank_id=len(vocab_list), num_processes=num_processes, ext_scoring_func=self._ext_scorer, - cutoff_prob=cutoff_prob) + cutoff_prob=cutoff_prob, + cutoff_top_n=cutoff_top_n) results = [result[0][1] for result in beam_search_results] else: diff --git a/deep_speech_2/requirements.txt b/deep_speech_2/requirements.txt index 131f75ff47..e104f633c7 100644 --- a/deep_speech_2/requirements.txt +++ b/deep_speech_2/requirements.txt @@ -2,4 +2,3 @@ scipy==0.13.1 resampy==0.1.5 SoundFile==0.9.0.post1 python_speech_features -https://github.com/luotao1/kenlm/archive/master.zip diff --git a/deep_speech_2/setup.sh b/deep_speech_2/setup.sh index 15c6e1e25e..7c40415db3 100644 --- a/deep_speech_2/setup.sh +++ b/deep_speech_2/setup.sh @@ -1,4 +1,4 @@ -#! /usr/bin/env bash +#! /usr/bin/env bash # install python dependencies if [ -f "requirements.txt" ]; then @@ -20,10 +20,19 @@ if [ $? != 0 ]; then fi tar -zxvf libsndfile-1.0.28.tar.gz cd libsndfile-1.0.28 - ./configure && make && make install + ./configure > /dev/null && make > /dev/null && make install > /dev/null cd .. rm -rf libsndfile-1.0.28 rm libsndfile-1.0.28.tar.gz fi +# install decoders +python -c "import swig_decoders" +if [ $? != 0 ]; then + cd decoders/swig > /dev/null + sh setup.sh + cd - > /dev/null +fi + + echo "Install all dependencies successfully." diff --git a/deep_speech_2/test.py b/deep_speech_2/test.py index 18089f3325..40f0795a13 100644 --- a/deep_speech_2/test.py +++ b/deep_speech_2/test.py @@ -22,9 +22,10 @@ add_arg('num_conv_layers', int, 2, "# of convolution layers.") add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") -add_arg('alpha', float, 0.36, "Coef of LM for beam search.") -add_arg('beta', float, 0.25, "Coef of WC for beam search.") -add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") +add_arg('alpha', float, 2.15, "Coef of LM for beam search.") +add_arg('beta', float, 0.35, "Coef of WC for beam search.") +add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.") +add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") add_arg('use_gpu', bool, True, "Use GPU or not.") add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " @@ -85,6 +86,9 @@ def evaluate(): pretrained_model_path=args.model_path, share_rnn_weights=args.share_rnn_weights) + # decoders only accept string encoded in utf-8 + vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + error_rate_func = cer if args.error_rate_type == 'cer' else wer error_sum, num_ins = 0.0, 0 for infer_data in batch_reader(): @@ -95,7 +99,8 @@ def evaluate(): beam_beta=args.beta, beam_size=args.beam_size, cutoff_prob=args.cutoff_prob, - vocab_list=data_generator.vocab_list, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, language_model_path=args.lang_model_path, num_processes=args.num_proc_bsearch) target_transcripts = [ @@ -110,6 +115,7 @@ def evaluate(): print("Final error rate [%s] (%d/%d) = %f" % (args.error_rate_type, num_ins, num_ins, error_sum / num_ins)) + ds2_model.logger.info("finish evaluation") def main(): print_arguments(args)