Skip to content

Commit

Permalink
[src] nnet3: extend what descriptors can be parsed. (#2780)
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey authored Oct 13, 2018
1 parent 087c21f commit a10e56e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/nnet3/nnet-descriptor-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ void UnitTestGeneralDescriptorSpecial() {
names.push_back("d");
KALDI_ASSERT(NormalizeTextDescriptor(names, "a") == "a");
KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(-1.0, a)") == "Scale(-1, a)");
KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(-1.0, Scale(-2.0, a))") == "Scale(2, a)");
KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(2.0, Sum(Scale(2.0, a), b, c))") ==
"Sum(Scale(4, a), Sum(Scale(2, b), Scale(2, c)))");
KALDI_ASSERT(NormalizeTextDescriptor(names, "Const(1.0, 512)") == "Const(1, 512)");
KALDI_ASSERT(NormalizeTextDescriptor(names, "Sum(Const(1.0, 512), Scale(-1.0, a))") ==
"Sum(Const(1, 512), Scale(-1, a))");
Expand Down
23 changes: 23 additions & 0 deletions src/nnet3/nnet-descriptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,29 @@ bool GeneralDescriptor::Normalize(GeneralDescriptor *desc) {
std::swap(desc->value1_, child->value1_);
std::swap(desc->value2_, child->value2_);
changed = true;
} else if (child->descriptor_type_ == kSum) {
// Push the Scale() inside the sum expression.
desc->descriptors_.clear();
for (size_t i = 0; i < child->descriptors_.size(); i++) {
GeneralDescriptor *new_child =
new GeneralDescriptor(kScale, -1, -1, desc->alpha_);
new_child->descriptors_.push_back(child->descriptors_[i]);
desc->descriptors_.push_back(new_child);
}
desc->descriptor_type_ = kSum;
desc->alpha_ = 0.0;
child->descriptors_.clear(); // prevent them being freed.
delete child;
changed = true;
} else if (child->descriptor_type_ == kScale) {
// Combine the 'scale' expressions.
KALDI_ASSERT(child->descriptors_.size() == 1);
GeneralDescriptor *grandchild = child->descriptors_[0];
desc->alpha_ *= child->alpha_;
desc->descriptors_[0] = grandchild;
child->descriptors_.clear(); // prevent them being freed.
delete child;
changed = true;
} else if (child->descriptor_type_ != kNodeName) {
KALDI_ERR << "Unhandled case encountered when normalizing Descriptor; "
"you can work around this by pushing Scale() inside "
Expand Down

0 comments on commit a10e56e

Please sign in to comment.