
compute probs in log format
Yibing Liu committed Jul 20, 2017
1 parent ae05535 commit ac18ee5
Showing 3 changed files with 65 additions and 36 deletions.
90 changes: 57 additions & 33 deletions deep_speech_2/deploy/ctc_beam_search_decoder.cpp
@@ -3,6 +3,7 @@
#include <algorithm>
#include <utility>
#include <cmath>
#include <float.h>
#include "ctc_beam_search_decoder.h"

template <typename T1, typename T2>
@@ -17,6 +18,14 @@ bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
return a.second > b.second;
}

template <typename T>
T log_sum_exp(T x, T y)
{
T xmax = std::max(x, y);
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
}
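The new helper is the standard log-sum-exp trick: subtracting the larger argument before exponentiating keeps std::exp from overflowing or underflowing, so log(exp(x) + exp(y)) stays finite even for very negative log probabilities. A minimal standalone sketch of the same idea (not part of this commit; main() and the sample values are illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Stable log(exp(x) + exp(y)), same formulation as the template added above.
double log_sum_exp_demo(double x, double y) {
    double xmax = std::max(x, y);
    return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}

int main() {
    // Two very small log probabilities; exp(-1000) underflows to 0 in double precision.
    double a = -1000.0, b = -1001.0;
    // Naive sum: log(0 + 0) = -inf.
    double naive = std::log(std::exp(a) + std::exp(b));
    // Stable sum: -1000 + log(1 + exp(-1)) is roughly -999.687.
    double stable = log_sum_exp_demo(a, b);
    std::printf("naive = %f, stable = %f\n", naive, stable);
    return 0;
}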


std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
int beam_size,
@@ -54,80 +63,94 @@ std::vector<std::pair<double, std::string> >
// two sets containing selected and candidate prefixes respectively
std::map<std::string, double> prefix_set_prev, prefix_set_next;
// probability of prefixes ending with blank and non-blank
std::map<std::string, double> probs_b_prev, probs_nb_prev;
std::map<std::string, double> probs_b_cur, probs_nb_cur;
prefix_set_prev["\t"] = 1.0;
probs_b_prev["\t"] = 1.0;
probs_nb_prev["\t"] = 0.0;
std::map<std::string, double> log_probs_b_prev, log_probs_nb_prev;
std::map<std::string, double> log_probs_b_cur, log_probs_nb_cur;
prefix_set_prev["\t"] = 0.0;
log_probs_b_prev["\t"] = 0.0;
log_probs_nb_prev["\t"] = FLT_MIN;
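One caveat worth noting: FLT_MIN (from <float.h>) is the smallest positive normalized float, roughly 1.2e-38, so as a log probability it behaves like log(1) rather than log(0). Where a true "impossible prefix" sentinel is wanted, a large negative value is the usual choice, as this small check shows (illustrative only; kLogZero is a hypothetical name, not part of this commit):

#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
    // exp(FLT_MIN) is effectively exp(0) = 1, so FLT_MIN stands in for log(1), not log(0).
    std::printf("exp(FLT_MIN)  = %f\n", std::exp((double)FLT_MIN));   // prints 1.000000
    // A large negative sentinel behaves like log(0): exp of it underflows to zero.
    const double kLogZero = -FLT_MAX;
    std::printf("exp(kLogZero) = %f\n", std::exp(kLogZero));          // prints 0.000000
    return 0;
}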

for (int time_step=0; time_step<num_time_steps; time_step++) {
prefix_set_next.clear();
probs_b_cur.clear();
probs_nb_cur.clear();
log_probs_b_cur.clear();
log_probs_nb_cur.clear();
std::vector<double> prob = probs_seq[time_step];

std::vector<std::pair<int, double> > prob_idx;
for (int i=0; i<prob.size(); i++) {
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
}
// pruning of vocabulary
int cutoff_len = prob.size();
if (cutoff_prob < 1.0) {
std::sort(prob_idx.begin(), prob_idx.end(),
pair_comp_second_rev<int, double>);
float cum_prob = 0.0;
int cutoff_len = 0;
cutoff_len = 0;
for (int i=0; i<prob_idx.size(); i++) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob) break;
}
prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
prob_idx.begin() + cutoff_len);
prob_idx.begin() + cutoff_len);
}

std::vector<std::pair<int, float> > log_prob_idx;
for (int i=0; i<cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, float>
(prob_idx[i].first, log(prob_idx[i].second)));
}
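The two blocks above prune the per-timestep distribution and move it into the log domain: tokens are sorted by probability, the smallest set whose cumulative mass reaches cutoff_prob is kept, and the survivors are stored as (index, log probability) pairs. A self-contained sketch of that pruning rule (toy values; not taken from this file):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// Sort pairs by probability, largest first.
static bool by_prob_desc(const std::pair<int, double>& a,
                         const std::pair<int, double>& b) {
    return a.second > b.second;
}

int main() {
    // Toy distribution over five tokens at one time step.
    std::vector<double> prob;
    prob.push_back(0.05); prob.push_back(0.40); prob.push_back(0.02);
    prob.push_back(0.33); prob.push_back(0.20);
    double cutoff_prob = 0.9;

    // Pair each token id with its probability and sort by probability, descending.
    std::vector<std::pair<int, double> > prob_idx;
    for (int i = 0; i < (int)prob.size(); i++)
        prob_idx.push_back(std::make_pair(i, prob[i]));
    std::sort(prob_idx.begin(), prob_idx.end(), by_prob_desc);

    // Keep the shortest prefix whose cumulative probability reaches cutoff_prob.
    double cum_prob = 0.0;
    int cutoff_len = 0;
    for (int i = 0; i < (int)prob_idx.size(); i++) {
        cum_prob += prob_idx[i].second;
        cutoff_len++;
        if (cum_prob >= cutoff_prob) break;
    }
    prob_idx.resize(cutoff_len);

    // Survivors are carried forward as log probabilities.
    for (int i = 0; i < (int)prob_idx.size(); i++)
        std::printf("token %d: log prob %.4f\n",
                    prob_idx[i].first, std::log(prob_idx[i].second));
    return 0;
}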
// extend prefix
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
for (std::map<std::string, double>::iterator
it = prefix_set_prev.begin();
it != prefix_set_prev.end(); it++) {
std::string l = it->first;
if( prefix_set_next.find(l) == prefix_set_next.end()) {
probs_b_cur[l] = probs_nb_cur[l] = 0.0;
log_probs_b_cur[l] = log_probs_nb_cur[l] = FLT_MIN;
}

for (int index=0; index<prob_idx.size(); index++) {
int c = prob_idx[index].first;
double prob_c = prob_idx[index].second;
std::cout<<"l = "<<l<<", log_b_cur = "<<log_probs_b_prev[l]<<", log_nb_cur = "<<log_probs_nb_prev[l]<<std::endl;
for (int index=0; index<log_prob_idx.size(); index++) {
int c = log_prob_idx[index].first;
float log_prob_c = log_prob_idx[index].second;
float log_probs_prev;
if (c == blank_id) {
probs_b_cur[l] += prob_c * (probs_b_prev[l] + probs_nb_prev[l]);
log_probs_prev = log_sum_exp<float>(log_probs_b_prev[l],
log_probs_nb_prev[l]);
log_probs_b_cur[l] = log_sum_exp<float>(log_probs_b_cur[l],
log_prob_c+log_probs_prev);
} else {
std::string last_char = l.substr(l.size()-1, 1);
std::string new_char = vocabulary[c];
std::string l_plus = l + new_char;

if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0;
log_probs_b_cur[l_plus] = FLT_MIN;
log_probs_nb_cur[l_plus] = FLT_MIN;
}
if (last_char == new_char) {
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l];
probs_nb_cur[l] += prob_c * probs_nb_prev[l];
log_probs_nb_cur[l_plus] = log_sum_exp<float>(log_probs_nb_cur[l_plus], log_prob_c+log_probs_b_prev[l]);
log_probs_nb_cur[l] = log_sum_exp<float>(log_probs_nb_cur[l], log_prob_c+log_probs_nb_prev[l]);
} else if (new_char == " ") {
double score = 1.0;
float score = 0.0;
if (ext_scorer != NULL && l.size() > 1) {
score = ext_scorer->get_score(l.substr(1));
score = ext_scorer->get_score(l.substr(1), true);
}
probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l]);
log_probs_prev = log_sum_exp<float>(log_probs_b_prev[l], log_probs_nb_prev[l]);
log_probs_nb_cur[l_plus] = log_sum_exp<float>(log_probs_nb_cur[l_plus], score+log_prob_c+log_probs_prev);
} else {
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l]);
log_probs_prev = log_sum_exp<float>(log_probs_b_prev[l], log_probs_nb_prev[l]);
log_probs_nb_cur[l_plus] = log_sum_exp<float>(log_probs_nb_cur[l_plus], log_prob_c+log_probs_prev);
}
prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus];
prefix_set_next[l_plus] = log_sum_exp<float>(log_probs_nb_cur[l_plus], log_probs_b_cur[l_plus]);
}
}

prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l];
prefix_set_next[l] = log_sum_exp<float>(log_probs_b_cur[l], log_probs_nb_cur[l]);
}

probs_b_prev = probs_b_cur;
probs_nb_prev = probs_nb_cur;
log_probs_b_prev = log_probs_b_cur;
log_probs_nb_prev = log_probs_nb_cur;
std::vector<std::pair<std::string, double> >
prefix_vec_next(prefix_set_next.begin(),
prefix_set_next.end());
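Taken together, the loop above is the usual CTC prefix beam search recurrence with every product of probabilities replaced by a sum of log probabilities and every sum of probabilities replaced by log_sum_exp. A small equivalence check for the blank-extension update (sample values only; not from this commit):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Stable log(exp(x) + exp(y)), matching the helper used by the decoder.
static double log_sum_exp(double x, double y) {
    double m = std::max(x, y);
    return std::log(std::exp(x - m) + std::exp(y - m)) + m;
}

int main() {
    // Linear-domain update: probs_b_cur[l] += prob_c * (probs_b_prev[l] + probs_nb_prev[l])
    double prob_c = 0.3, probs_b_prev = 0.1, probs_nb_prev = 0.05, probs_b_cur = 0.02;
    double linear = probs_b_cur + prob_c * (probs_b_prev + probs_nb_prev);

    // Log-domain form, as in the new code:
    // log_probs_b_cur[l] = log_sum_exp(log_probs_b_cur[l],
    //                          log_prob_c + log_sum_exp(log_probs_b_prev[l], log_probs_nb_prev[l]))
    double log_result = log_sum_exp(
        std::log(probs_b_cur),
        std::log(prob_c) + log_sum_exp(std::log(probs_b_prev), std::log(probs_nb_prev)));

    // The two agree up to rounding: exp(log_result) equals linear.
    std::printf("linear = %.6f, exp(log) = %.6f\n", linear, std::exp(log_result));
    return 0;
}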
@@ -143,15 +166,16 @@ std::vector<std::pair<double, std::string> >
std::vector<std::pair<double, std::string> > beam_result;
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
it != prefix_set_prev.end(); it++) {
if (it->second > 0.0 && it->first.size() > 1) {
double prob = it->second;
if (it->second > FLT_MIN && it->first.size() > 1) {
float log_prob = it->second;
std::string sentence = it->first.substr(1);
// scoring the last word
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
prob = prob * ext_scorer->get_score(sentence);
log_prob = log_prob + ext_scorer->get_score(sentence, true);
}
if (log_prob > FLT_MIN) {
beam_result.push_back(std::pair<double, std::string>(log_prob, sentence));
}
double log_prob = log(prob);
beam_result.push_back(std::pair<double, std::string>(log_prob, sentence));
}
}
// sort the result and return
9 changes: 7 additions & 2 deletions deep_speech_2/deploy/scorer.cpp
@@ -89,10 +89,15 @@ void Scorer::reset_params(float alpha, float beta) {
this->_beta = beta;
}

double Scorer::get_score(std::string sentence) {
double Scorer::get_score(std::string sentence, bool log) {
double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence);

double final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta);
double final_score = 0.0;
if (log == false) {
final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta);
} else {
final_score = _alpha*lm_score*std::log(10) + _beta*std::log(word_cnt);
}
return final_score;
}
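With the new log flag, get_score returns the natural logarithm of the original score: pow(10, _alpha*lm_score) * pow(word_cnt, _beta) becomes _alpha*lm_score*ln(10) + _beta*ln(word_cnt), which is what the decoder now adds to its log probabilities. A quick numeric check of that identity (the alpha, beta, lm_score, and word_cnt values are illustrative only):

#include <cmath>
#include <cstdio>

int main() {
    // Illustrative LM weight, word-insertion weight, log10 LM score, and word count.
    double alpha = 2.0, beta = 0.5, lm_score = -3.2;
    int word_cnt = 4;

    // Linear-domain score, as returned when log == false.
    double linear = std::pow(10.0, alpha * lm_score) * std::pow((double)word_cnt, beta);
    // Log-domain score, as returned when log == true.
    double log_score = alpha * lm_score * std::log(10.0) + beta * std::log((double)word_cnt);

    // log(linear) and log_score agree up to rounding.
    std::printf("log(linear) = %.6f, log_score = %.6f\n", std::log(linear), log_score);
    return 0;
}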
2 changes: 1 addition & 1 deletion deep_speech_2/deploy/scorer.h
@@ -30,7 +30,7 @@ class Scorer{
// reset params alpha & beta
void reset_params(float alpha, float beta);
// get the final score
double get_score(std::string);
double get_score(std::string, bool log=false);
};

#endif //SCORER_H_
