Skip to content

Commit

Permalink
Merge pull request kaldi-asr#23 from david-ryan-snyder/xvector2
Browse files: browse the repository at this point in the history
Xvector: another update to nnet3-xvector-compute
  • Loading branch information
danpovey committed Mar 1, 2016
2 parents 59ed8da + 14059ad commit 0f443a3
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/xvectorbin/nnet3-xvector-compute.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -86,6 +86,7 @@ int main(int argc, char *argv[]) {
<< "be true.";
if (chunk_size == -1)
chunk_size = xvector_period;

KALDI_ASSERT(chunk_size > 0 && xvector_period > 0);

std::string nnet_rxfilename = po.GetArg(1),
Expand All @@ -102,30 +103,42 @@ int main(int argc, char *argv[]) {

int32 num_success = 0,
num_fail = 0,
left_context,
right_context,
xvector_dim = nnet.OutputDim("output");
ComputeSimpleNnetContext(nnet, &left_context, &right_context);
int32 min_chunk_size = left_context + right_context;
int64 frame_count = 0;

SequentialBaseFloatMatrixReader feat_reader(feat_rspecifier);
for (; !feat_reader.Done(); feat_reader.Next()) {
std::string utt = feat_reader.Key();
const Matrix<BaseFloat> &feats (feat_reader.Value());
int32 num_rows = feats.NumRows(),
feat_dim = feats.NumCols();
if (num_rows < chunk_size) {
KALDI_WARN << "The chunk-size is greater than the number of rows "
feat_dim = feats.NumCols(),
this_chunk_size = chunk_size;

if (num_rows < min_chunk_size) {
KALDI_WARN << "Minimum chunk size of " << min_chunk_size
<< " is greater than the number of rows "
<< "in utterance: " << utt;
num_fail++;
continue;
} else if (num_rows < this_chunk_size) {
KALDI_LOG << "Chunk size of " << this_chunk_size << " is greater than "
<< "the number of rows in utterance: " << utt
<< ", using chunk size of " << num_rows;
this_chunk_size = num_rows;
}

int32 num_chunks = ceil((num_rows - chunk_size)
int32 num_chunks = ceil((num_rows - this_chunk_size)
/ static_cast<BaseFloat>(xvector_period)) + 1,
num_xvectors = repeat ? num_rows : ceil(num_rows
/ static_cast<BaseFloat>(xvector_period));

// The number of frames by which the last two chunks overlap.
int32 overlap = std::max(0, (num_chunks - 1) * xvector_period
- num_rows + chunk_size);
- num_rows + this_chunk_size);
BaseFloat total_chunk_weight = 0.0;
Vector<BaseFloat> xvector_avg;
Matrix<BaseFloat> xvector_mat;
Expand All @@ -140,11 +153,11 @@ int main(int argc, char *argv[]) {
// Iterate over the feature chunks.
for (int32 chunk_indx = 0; chunk_indx < num_chunks; chunk_indx++) {
// If we're nearing the end of the input, we may need to shift the
// offset back so that we can get chunk_size frames of input to the
// nnet.
// offset back so that we can get this_chunk_size frames of input to
// the nnet.
int32 offset = std::min(chunk_indx * xvector_period,
num_rows - chunk_size);
SubMatrix<BaseFloat> sub_feats(feats, offset, chunk_size,
num_rows - this_chunk_size);
SubMatrix<BaseFloat> sub_feats(feats, offset, this_chunk_size,
0, feat_dim);
Vector<BaseFloat> xvector(xvector_dim);
nnet_computer.ComputeXvector(sub_feats, &xvector);
Expand All @@ -155,9 +168,9 @@ int main(int argc, char *argv[]) {
// chunks, so that the overlapping portion isn't counted twice.
BaseFloat weight;
if (chunk_indx < num_chunks - 2)
weight = chunk_size;
weight = this_chunk_size;
else
weight = chunk_size - 0.5 * overlap;
weight = this_chunk_size - 0.5 * overlap;
total_chunk_weight += weight;
xvector_avg.AddVec(weight, xvector);
// Cases for outputting as a matrix:
Expand Down

0 comments on commit 0f443a3

Please sign in to comment.