Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Updates #1306

Merged
merged 28 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b4c70f2
add code for simplified hybrid estimation
varunagrawal Sep 4, 2022
b2ca747
almost done with single legged robot
varunagrawal Sep 4, 2022
8f94f72
single leg robot test
varunagrawal Sep 4, 2022
b5eaaab
add more tests to show scheme doesn't work
varunagrawal Sep 4, 2022
858193d
remove print statements
varunagrawal Sep 11, 2022
269d60e
Merge branch 'develop' into varun/test-hybrid-estimation
varunagrawal Oct 3, 2022
bf2015d
Merge branch 'hybrid/check-elimination' into varun/test-hybrid-estima…
varunagrawal Oct 3, 2022
75b2599
remove unnecessary comments
varunagrawal Oct 3, 2022
42e915f
Merge branch 'hybrid/improved-prune-2' into varun/test-hybrid-estimation
varunagrawal Oct 3, 2022
812ae30
Merge branch 'hybrid/gaussian-conditional' into varun/test-hybrid-est…
varunagrawal Oct 4, 2022
6bd16d9
move GTDKeyFormatter to types.h
varunagrawal Oct 7, 2022
acdc1af
Merge branch 'develop' into varun/test-hybrid-estimation
varunagrawal Oct 8, 2022
9c77c66
fix tests
varunagrawal Oct 8, 2022
1a17a81
formatting
varunagrawal Oct 9, 2022
055df81
apply keyformatter to Gaussian iSAM in HybridNonlinearISAM
varunagrawal Oct 9, 2022
3e15184
slight improvement to GaussianMixtureFactor print
varunagrawal Oct 10, 2022
a00bcbc
PrunerFunc helper function
varunagrawal Oct 10, 2022
2c8fe25
enumerate missing discrete keys so we can prune all gaussian mixtures
varunagrawal Oct 10, 2022
c15cfb6
add PrunerFunc to GaussianMixture
varunagrawal Oct 11, 2022
2225ecf
clean up the prunerFunc
varunagrawal Oct 11, 2022
5e99cd7
HybridBayesNet and HybridBayesTree both use similar pruning functions
varunagrawal Oct 11, 2022
c2377f3
minor fixes to unit test
varunagrawal Oct 11, 2022
8090812
add assertions to remove warning
varunagrawal Oct 11, 2022
eacc888
remove print function
varunagrawal Oct 11, 2022
8c19f49
add curly brackets to for loop
varunagrawal Oct 13, 2022
1d70d14
remove custom keyformatter
varunagrawal Oct 13, 2022
8f7473d
remove added test file
varunagrawal Oct 13, 2022
0faf222
remove leftover comment
varunagrawal Oct 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions gtsam/discrete/tests/testDiscreteBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ TEST(DiscreteBayesTree, ThinTree) {
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);

DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);

// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
Expand Down
76 changes: 66 additions & 10 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,25 +128,81 @@ void GaussianMixture::print(const std::string &s,
});
}

/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&decisionTree](
/* ************************************************************************* */
/// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
return s;
}

/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
ProfFan marked this conversation as resolved.
Show resolved Hide resolved
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());

auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);

if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
if (gaussianMixtureKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
} else {
return conditional;
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
gaussianMixtureKeySet.begin(),
gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));

const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());

// If any one of the sub-branches are non-zero,
// we need this conditional.
if (decisionTree(augmented_values) > 0.0) {
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return nullptr;
}
};
return pruner;
}

/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = prunerFunc(decisionTree);

auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_;
Expand Down
11 changes: 11 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ class GTSAM_EXPORT GaussianMixture
*/
Sum asGaussianFactorGraphTree() const;

/**
* @brief Helper function to get the pruner functor.
*
* @param decisionTree The pruned discrete probability decision tree.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &decisionTree);

public:
/// @name Constructors
/// @{
Expand Down
9 changes: 5 additions & 4 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ void GaussianMixtureFactor::print(const std::string &s,
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf)
if (gf && !gf->empty()) {
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
return rd.str();
} else {
return "nullptr";
}
});
std::cout << "}" << std::endl;
}
Expand Down
69 changes: 9 additions & 60 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,6 @@

namespace gtsam {

/* ************************************************************************* */
/// Return the DiscreteKey vector as a set.
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
return s;
}

/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;
Expand Down Expand Up @@ -66,61 +58,18 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {

HybridBayesNet prunedBayesNetFragment;

// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);

if ((*discreteFactor)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};

// Go through all the conditionals in the
// Bayes Net and prune them as per discreteFactor.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);

GaussianMixture::shared_ptr gaussianMixture =
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());

if (gaussianMixture) {
// We may have mixtures with less discrete keys than discreteFactor so we
// skip those since the label assignment does not exist.
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
if (gmKeySet != dfKeySet) {
// Add the gaussianMixture which doesn't have to be pruned.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(gaussianMixture));
continue;
}

// Run the pruning to get a new, pruned tree
GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner);

DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
// reverse keys to get a natural ordering
std::reverse(discreteKeys.begin(), discreteKeys.end());

// Convert from boost::iterator_range to KeyVector
// so we can pass it to constructor.
KeyVector frontals(gaussianMixture->frontals().begin(),
gaussianMixture->frontals().end()),
parents(gaussianMixture->parents().begin(),
gaussianMixture->parents().end());

// Create the new gaussian mixture and add it to the bayes net.
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
frontals, parents, discreteKeys, prunedTree);
if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();

// Make a copy of the gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*discreteFactor);

// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
Expand Down Expand Up @@ -173,7 +122,7 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn;
}

/* *******************************************************************************/
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE
DiscreteBayesNet discrete_bn;
Expand All @@ -190,7 +139,7 @@ HybridValues HybridBayesNet::optimize() const {
return HybridValues(mpe, gbn.optimize());
}

/* *******************************************************************************/
/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
return gbn.optimize();
Expand Down
20 changes: 7 additions & 13 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());

DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;

/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteFactor;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
DecisionTreeFactor prunedDecisionTree;
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteFactor(prunedDiscreteFactor) {}
: prunedDecisionTree(prunedDecisionTree) {}

/**
* @brief A function used during tree traversal that operates on each node
Expand All @@ -178,19 +178,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();

// Check if the number of discrete keys match,
// else we get an assignment error.
// TODO(Varun) Update prune method to handle assignment subset?
if (gaussianMixture->discreteKeys() ==
parentData.prunedDiscreteFactor.discreteKeys()) {
gaussianMixture->prune(parentData.prunedDiscreteFactor);
}
gaussianMixture->prune(parentData.prunedDecisionTree);
}
return parentData;
}
};

HybridPrunerData rootData(prunedDiscreteFactor, 0);
HybridPrunerData rootData(prunedDecisionTree, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
Expand Down
5 changes: 1 addition & 4 deletions gtsam/hybrid/HybridFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,

/* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &keys)
: Base(keys),
isContinuous_(true),
continuousKeys_(keys) {}
: Base(keys), isContinuous_(true), continuousKeys_(keys) {}

/* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
Expand Down Expand Up @@ -101,7 +99,6 @@ void HybridFactor::print(const std::string &s,
if (d < discreteKeys_.size() - 1) {
std::cout << " ";
}

}
std::cout << "]";
}
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
result = EliminatePreferCholesky(graph, frontalKeys);

if (keysOfEliminated.empty()) {
keysOfEliminated =
result.first->keys(); // Initialize the keysOfEliminated to be the
// Initialize the keysOfEliminated to be the keys of the
// eliminated GaussianConditional
keysOfEliminated = result.first->keys();
}
// keysOfEliminated of the GaussianConditional
if (keysOfSeparator.empty()) {
keysOfSeparator = result.second->keys();
}
Expand Down
5 changes: 3 additions & 2 deletions gtsam/hybrid/HybridGaussianISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ void HybridGaussianISAM::updateInternal(
factors += newFactors;

// Add the orphaned subtrees
for (const sharedClique& orphan : *orphans)
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan);
for (const sharedClique& orphan : *orphans) {
factors += boost::make_shared<BayesTreeOrphanWrapper<Node>>(orphan);
}

// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys();
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridNonlinearISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void HybridNonlinearISAM::print(const string& s,
const KeyFormatter& keyFormatter) const {
cout << s << "ReorderInterval: " << reorderInterval_
<< " Current Count: " << reorderCounter_ << endl;
isam_.print("HybridGaussianISAM:\n");
isam_.print("HybridGaussianISAM:\n", keyFormatter);
linPoint_.print("Linearization Point:\n", keyFormatter);
factors_.print("Nonlinear Graph:\n", keyFormatter);
}
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridGaussianISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
EXPECT_LONGS_EQUAL(
2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/tests/testHybridNonlinearISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/**
* @file testHybridNonlinearISAM.cpp
* @brief Unit tests for nonlinear incremental inference
* @author Fan Jiang, Varun Agrawal, Frank Dellaert
* @author Varun Agrawal, Fan Jiang, Frank Dellaert
* @date Jan 2021
*/

Expand Down Expand Up @@ -363,7 +363,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
EXPECT_LONGS_EQUAL(
2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
4, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
Expand Down Expand Up @@ -432,9 +432,9 @@ TEST(HybridNonlinearISAM, NonTrivial) {

// Don't run update now since we don't have discrete variables involved.

/*************** Run Round 2 ***************/
using PlanarMotionModel = BetweenFactor<Pose2>;

/*************** Run Round 2 ***************/
// Add odometry factor with discrete modes.
Pose2 odometry(1.0, 0.0, 0.0);
KeyVector contKeys = {W(0), W(1)};
Expand Down