Skip to content

Commit

Permalink
Merge pull request #21 from varunagrawal/feature/hybrid-bayes-net
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Feb 17, 2022
2 parents c20832e + 748687e commit ae02498
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 39 deletions.
51 changes: 29 additions & 22 deletions gtsam/hybrid/IncrementalHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
*/

#include <gtsam/hybrid/IncrementalHybrid.h>
#include <unordered_set>

#include <algorithm>
#include <unordered_set>

void gtsam::IncrementalHybrid::update(gtsam::GaussianHybridFactorGraph graph,
const gtsam::Ordering &ordering,
Expand All @@ -32,23 +33,30 @@ void gtsam::IncrementalHybrid::update(gtsam::GaussianHybridFactorGraph graph,
for (auto &&conditional : *hybridBayesNet_) {
for (auto &key : conditional->frontals()) {
if (allVars.find(key) != allVars.end()) {
if (auto
gf = boost::dynamic_pointer_cast<GaussianMixture>(conditional)) {
if (auto gf =
boost::dynamic_pointer_cast<GaussianMixture>(conditional)) {
graph.push_back(gf);
} else if (auto df =
boost::dynamic_pointer_cast<DiscreteConditional>(conditional)) {
} else if (auto df = boost::dynamic_pointer_cast<DiscreteConditional>(
conditional)) {
graph.push_back(df);
}
break;
}
}
}
} else {
// Initialize an empty HybridBayesNet
hybridBayesNet_ = boost::make_shared<HybridBayesNet>();
}

// Eliminate partially.
std::tie(hybridBayesNet_, remainingFactorGraph_) =
HybridBayesNet::shared_ptr bayesNetFragment;
std::tie(bayesNetFragment, remainingFactorGraph_) =
graph.eliminatePartialSequential(ordering);

// Add the partial bayes net to the posterior bayes net.
hybridBayesNet_->push_back<HybridBayesNet>(*bayesNetFragment);

// Prune
if (maxNrLeaves) {
const auto N = *maxNrLeaves;
Expand All @@ -62,17 +70,15 @@ void gtsam::IncrementalHybrid::update(gtsam::GaussianHybridFactorGraph graph,
// Let's assume that the structure of the last discrete density will be the
// same as the last continuous
std::vector<double> probabilities;
// TODO(fan): The number of probabilities can be lower than the actual number of choices
discreteFactor->visit([&](const double &prob) {
probabilities.emplace_back(prob);
});
// TODO(fan): The number of probabilities can be lower than the actual
// number of choices
discreteFactor->visit(
[&](const double &prob) { probabilities.emplace_back(prob); });

if (probabilities.size() < N) return;

std::nth_element(probabilities.begin(),
probabilities.begin() + N,
probabilities.end(),
std::greater<double>{});
std::nth_element(probabilities.begin(), probabilities.begin() + N,
probabilities.end(), std::greater<double>{});

auto thresholdValue = probabilities[N - 1];

Expand All @@ -83,14 +89,16 @@ void gtsam::IncrementalHybrid::update(gtsam::GaussianHybridFactorGraph graph,
DecisionTree<Key, double> thresholded(*discreteFactor, threshold);

// Create a new factor with pruned tree
// DecisionTreeFactor newFactor(discreteFactor->discreteKeys(), thresholded);
// DecisionTreeFactor newFactor(discreteFactor->discreteKeys(),
// thresholded);
discreteFactor->root_ = thresholded.root_;

std::vector<std::pair<DiscreteValues, double>> assignments = discreteFactor->enumerate();
std::vector<std::pair<DiscreteValues, double>> assignments =
discreteFactor->enumerate();

// Loop over all assignments and create a vector of GaussianConditionals
std::vector<GaussianFactor::shared_ptr> prunedConditionals;
for (auto && av : assignments) {
for (auto &&av : assignments) {
const DiscreteValues &assignment = av.first;
const double value = av.second;

Expand All @@ -101,11 +109,10 @@ void gtsam::IncrementalHybrid::update(gtsam::GaussianHybridFactorGraph graph,
}
}

GaussianMixture::Factors prunedConditionalsTree(
lastDensity->discreteKeys(),
prunedConditionals
);
GaussianMixture::Factors prunedConditionalsTree(lastDensity->discreteKeys(),
prunedConditionals);

hybridBayesNet_->atGaussian(hybridBayesNet_->size() - 1)->factors_ = prunedConditionalsTree;
hybridBayesNet_->atGaussian(hybridBayesNet_->size() - 1)->factors_ =
prunedConditionalsTree;
}
}
32 changes: 15 additions & 17 deletions gtsam/hybrid/tests/testIncrementalHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
* @date Jan 2021
*/

#include "Switching.h"

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/hybrid/DCFactor.h>
Expand All @@ -30,6 +28,8 @@

#include <numeric>

#include "Switching.h"

// Include for test suite
#include <CppUnitLite/TestHarness.h>

Expand Down Expand Up @@ -91,11 +91,11 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_inference) {

auto hybridBayesNet2 = incrementalHybrid.hybridBayesNet_;
CHECK(hybridBayesNet2);
EXPECT_LONGS_EQUAL(2, hybridBayesNet2->size());
EXPECT(hybridBayesNet2->at(0)->frontals() == KeyVector{X(2)});
EXPECT(hybridBayesNet2->at(0)->parents() == KeyVector({X(3), M(2), M(1)}));
EXPECT(hybridBayesNet2->at(1)->frontals() == KeyVector{X(3)});
EXPECT(hybridBayesNet2->at(1)->parents() == KeyVector({M(2), M(1)}));
EXPECT_LONGS_EQUAL(4, hybridBayesNet2->size());
EXPECT(hybridBayesNet2->at(2)->frontals() == KeyVector{X(2)});
EXPECT(hybridBayesNet2->at(2)->parents() == KeyVector({X(3), M(2), M(1)}));
EXPECT(hybridBayesNet2->at(3)->frontals() == KeyVector{X(3)});
EXPECT(hybridBayesNet2->at(3)->parents() == KeyVector({M(2), M(1)}));

auto remainingFactorGraph2 = incrementalHybrid.remainingFactorGraph_;
CHECK(remainingFactorGraph2);
Expand All @@ -117,16 +117,15 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_inference) {
switching.linearizedFactorGraph.eliminatePartialSequential(ordering);

// The densities on X(1) should be the same
EXPECT(
assert_equal(*(hybridBayesNet->atGaussian(0)),
*(expectedHybridBayesNet->atGaussian(0))));
EXPECT(assert_equal(*(hybridBayesNet->atGaussian(0)),
*(expectedHybridBayesNet->atGaussian(0))));

// The densities on X(2) should be the same
EXPECT(assert_equal(*(hybridBayesNet2->atGaussian(0)),
EXPECT(assert_equal(*(hybridBayesNet2->atGaussian(2)),
*(expectedHybridBayesNet->atGaussian(1))));

// The densities on X(3) should be the same
EXPECT(assert_equal(*(hybridBayesNet2->atGaussian(1)),
EXPECT(assert_equal(*(hybridBayesNet2->atGaussian(3)),
*(expectedHybridBayesNet->atGaussian(2))));

// we only do the manual continuous elimination for 0,0
Expand Down Expand Up @@ -236,7 +235,7 @@ TEST(DCGaussianElimination, Approx_inference) {
EXPECT(discreteFactor_m1.keys() == KeyVector({M(3), M(2), M(1)}));

// Check number of elements equal to zero
auto count = [](const double& value, int count) {
auto count = [](const double &value, int count) {
return value > 0 ? count + 1 : count;
};
EXPECT_LONGS_EQUAL(5, discreteFactor_m1.fold(count, 0));
Expand All @@ -246,7 +245,7 @@ TEST(DCGaussianElimination, Approx_inference) {
* factor 1: [x2 | x3 m2 m1 ], 4 components
* factor 2: [x3 | x4 m3 m2 m1 ], 8 components
* factor 3: [x4 | m3 m2 m1 ], 8 components
*/
*/
auto hybridBayesNet = incrementalHybrid.hybridBayesNet_;

CHECK(hybridBayesNet);
Expand All @@ -258,8 +257,8 @@ TEST(DCGaussianElimination, Approx_inference) {

auto &lastDensity = *(hybridBayesNet->atGaussian(3));
auto &unprunedLastDensity = *(unprunedHybridBayesNet->atGaussian(3));
std::vector<std::pair<DiscreteValues, double>>
assignments = discreteFactor_m1.enumerate();
std::vector<std::pair<DiscreteValues, double>> assignments =
discreteFactor_m1.enumerate();
// Loop over all assignments and check the pruned components
for (auto &&av : assignments) {
const DiscreteValues &assignment = av.first;
Expand Down Expand Up @@ -326,7 +325,6 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_approximate) {
EXPECT_LONGS_EQUAL(5, actualBayesNet.atGaussian(1)->nrComponents());
}


/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down

0 comments on commit ae02498

Please sign in to comment.