Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Pruning #1309

Merged
merged 11 commits into from
Nov 3, 2022
1 change: 0 additions & 1 deletion gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ void GaussianMixture::print(const std::string &s,
}

/* ************************************************************************* */
/// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
Expand Down
3 changes: 3 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ class GTSAM_EXPORT GaussianMixture
Sum add(const Sum &sum) const;
};

/// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);

// traits
template <>
struct traits<GaussianMixture> : public Testable<GaussianMixture> {};
Expand Down
95 changes: 91 additions & 4 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,100 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());

auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) {
return 0.0;
} else {
return probability;
}
} else {
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
conditionalKeySet.begin(), conditionalKeySet.end(),
std::back_inserter(set_diff));

const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());

// If any one of the sub-branches are non-zero,
// we need this probability.
if (decisionTree(augmented_values) > 0.0) {
return probability;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return 0.0;
}
};
return pruner;
}

/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys();

for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl;
auto discrete = conditional->asDiscreteConditional();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());

// Apply prunerFunc to the underlying AlgebraicDecisionTree
auto discreteTree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));

// Create the new (hybrid) conditional
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);

// Add it back to the BayesNet
this->at(i) = conditional;
}
}
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr discreteFactor =
const DecisionTreeFactor::shared_ptr decisionTree =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));

this->updateDiscreteConditionals(decisionTree);

/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
Expand All @@ -59,7 +146,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
HybridBayesNet prunedBayesNetFragment;

// Go through all the conditionals in the
// Bayes Net and prune them as per discreteFactor.
// Bayes Net and prune them as per decisionTree.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);

Expand All @@ -69,7 +156,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Make a copy of the gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*discreteFactor);
prunedGaussianMixture->prune(*decisionTree);

// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
Expand Down
11 changes: 9 additions & 2 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;

protected:
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
Expand All @@ -121,11 +120,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {

public:
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const;
HybridBayesNet prune(size_t maxNrLeaves);

/// @}

private:
/**
* @brief Update the discrete conditionals with the pruned versions.
*
* @param prunedDecisionTree
*/
void updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {

/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());
auto decisionTree =
this->roots_.at(0)->conditional()->asDiscreteConditional();

DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
Expand Down
112 changes: 112 additions & 0 deletions gtsam/hybrid/HybridSmoother.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/* ----------------------------------------------------------------------------

* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)

* See LICENSE for the license information

* -------------------------------------------------------------------------- */

/**
* @file HybridSmoother.cpp
* @brief An incremental smoother for hybrid factor graphs
* @author Varun Agrawal
* @date October 2022
*/

#include <gtsam/hybrid/HybridSmoother.h>

#include <algorithm>
#include <unordered_set>

namespace gtsam {

/* ************************************************************************* */
void HybridSmoother::update(HybridGaussianFactorGraph graph,
const Ordering &ordering,
boost::optional<size_t> maxNrLeaves) {
// Add the necessary conditionals from the previous timestep(s).
std::tie(graph, hybridBayesNet_) =
addConditionals(graph, hybridBayesNet_, ordering);

// Eliminate.
auto bayesNetFragment = graph.eliminateSequential(ordering);

/// Prune
if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment.
HybridBayesNet prunedBayesNetFragment =
bayesNetFragment->prune(*maxNrLeaves);
// Set the bayes net fragment to the pruned version
bayesNetFragment =
boost::make_shared<HybridBayesNet>(prunedBayesNetFragment);
}

// Add the partial bayes net to the posterior bayes net.
hybridBayesNet_.push_back<HybridBayesNet>(*bayesNetFragment);
}

/* ************************************************************************* */
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
const HybridBayesNet &originalHybridBayesNet,
const Ordering &ordering) const {
HybridGaussianFactorGraph graph(originalGraph);
HybridBayesNet hybridBayesNet(originalHybridBayesNet);

// If we are not at the first iteration, means we have conditionals to add.
if (!hybridBayesNet.empty()) {
// We add all relevant conditional mixtures on the last continuous variable
// in the previous `hybridBayesNet` to the graph

// Conditionals to remove from the bayes net
// since the conditional will be updated.
std::vector<HybridConditional::shared_ptr> conditionals_to_erase;

// New conditionals to add to the graph
gtsam::HybridBayesNet newConditionals;

// NOTE(Varun) Using a for-range loop doesn't work since some of the
// conditionals are invalid pointers
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, I don't think this is possible in a Hybrid Bayes Net?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is! I was able to get a simple example up and running with this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no other explanation why a for-range loop doesn't work, unfortunately...

for (size_t i = 0; i < hybridBayesNet.size(); i++) {
auto conditional = hybridBayesNet.at(i);

for (auto &key : conditional->frontals()) {
if (std::find(ordering.begin(), ordering.end(), key) !=
ordering.end()) {
newConditionals.push_back(conditional);
conditionals_to_erase.push_back(conditional);

break;
}
}
}
// Remove conditionals at the end so we don't affect the order in the
// original bayes net.
for (auto &&conditional : conditionals_to_erase) {
auto it = find(hybridBayesNet.begin(), hybridBayesNet.end(), conditional);
hybridBayesNet.erase(it);
}

graph.push_back(newConditionals);
// newConditionals.print("\n\n\nNew Conditionals to add back");
}
return {graph, hybridBayesNet};
}

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
size_t index) const {
return boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet_.at(index));
}

/* ************************************************************************* */
const HybridBayesNet &HybridSmoother::hybridBayesNet() const {
return hybridBayesNet_;
}

} // namespace gtsam
73 changes: 73 additions & 0 deletions gtsam/hybrid/HybridSmoother.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/* ----------------------------------------------------------------------------

* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)

* See LICENSE for the license information

* -------------------------------------------------------------------------- */

/**
* @file HybridSmoother.h
* @brief An incremental smoother for hybrid factor graphs
* @author Varun Agrawal
* @date October 2022
*/

#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>

namespace gtsam {

class HybridSmoother {
private:
HybridBayesNet hybridBayesNet_;
HybridGaussianFactorGraph remainingFactorGraph_;

public:
/**
* Given new factors, perform an incremental update.
* The relevant densities in the `hybridBayesNet` will be added to the input
* graph (fragment), and then eliminated according to the `ordering`
* presented. The remaining factor graph contains Gaussian mixture factors
* that are not connected to the variables in the ordering, or a single
* discrete factor on all discrete keys, plus all discrete factors in the
* original graph.
*
* \note If maxComponents is given, we look at the discrete factor resulting
* from this elimination, and prune it and the Gaussian components
* corresponding to the pruned choices.
*
* @param graph The new factors, should be linear only
* @param ordering The ordering for elimination, only continuous vars are
* allowed
* @param maxNrLeaves The maximum number of leaves in the new discrete factor,
* if applicable
*/
void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
boost::optional<size_t> maxNrLeaves = boost::none);

/**
* @brief Add conditionals from previous timestep as part of liquefication.
*
* @param graph The new factor graph for the current time step.
* @param hybridBayesNet The hybrid bayes net containing all conditionals so
* far.
* @param ordering The elimination ordering.
* @return std::pair<HybridGaussianFactorGraph, HybridBayesNet>
*/
std::pair<HybridGaussianFactorGraph, HybridBayesNet> addConditionals(
const HybridGaussianFactorGraph& graph,
const HybridBayesNet& hybridBayesNet, const Ordering& ordering) const;

/// Get the Gaussian Mixture from the Bayes Net posterior at `index`.
GaussianMixture::shared_ptr gaussianMixture(size_t index) const;

/// Return the Bayes Net posterior.
const HybridBayesNet& hybridBayesNet() const;
};

}; // namespace gtsam
Loading