Skip to content

Commit

Permalink
Merge pull request #1373 from borglab/default-ordering-func
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Jan 8, 2023
2 parents a3b177c + ec51492 commit 690a230
Show file tree
Hide file tree
Showing 13 changed files with 131 additions and 129 deletions.
12 changes: 10 additions & 2 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,17 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys); }
return EliminateDiscrete(factors, keys);
}
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
boost::optional<const VariableIndex&> variableIndex) {
return Ordering::Colamd(*variableIndex);
}
};

/* ************************************************************************* */
Expand Down
30 changes: 15 additions & 15 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;

/* ************************************************************************ */
const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
KeySet discrete_keys = graph.discreteKeys();
for (auto &factor : graph) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(graph);
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
return ordering;
}

/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree,
Expand Down Expand Up @@ -448,21 +463,6 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = discreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
return ordering;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
Expand Down
23 changes: 15 additions & 8 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ GTSAM_EXPORT
std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr>
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* @return const Ordering
*/
GTSAM_EXPORT const Ordering
HybridOrdering(const HybridGaussianFactorGraph& graph);

/* ************************************************************************* */
template <>
struct EliminationTraits<HybridGaussianFactorGraph> {
Expand All @@ -74,6 +83,12 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateHybrid(factors, keys);
}
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
boost::optional<const VariableIndex&> variableIndex) {
return HybridOrdering(graph);
}
};

/**
Expand Down Expand Up @@ -228,14 +243,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
double probPrime(const HybridValues& values) const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* @return const Ordering
*/
const Ordering getHybridOrdering() const;

/**
* @brief Create a decision tree of factor graphs out of this hybrid factor
* graph.
Expand Down
15 changes: 5 additions & 10 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ TEST(HybridBayesNet, OptimizeAssignment) {
TEST(HybridBayesNet, Optimize) {
Switching s(4, 1.0, 0.1, {0, 1, 2, 3}, "1/1 1/1");

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
s.linearizedFactorGraph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();

Expand All @@ -212,9 +211,8 @@ TEST(HybridBayesNet, Optimize) {
TEST(HybridBayesNet, Error) {
Switching s(3);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
s.linearizedFactorGraph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous());
Expand Down Expand Up @@ -266,9 +264,8 @@ TEST(HybridBayesNet, Error) {
TEST(HybridBayesNet, Prune) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
s.linearizedFactorGraph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();

Expand All @@ -284,9 +281,8 @@ TEST(HybridBayesNet, Prune) {
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
s.linearizedFactorGraph.eliminateSequential();

size_t maxNrLeaves = 3;
auto discreteConditionals = hybridBayesNet->discreteConditionals();
Expand Down Expand Up @@ -353,8 +349,7 @@ TEST(HybridBayesNet, Sampling) {
// Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial);
// Eliminate into BN
Ordering ordering = fg->getHybridOrdering();
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
HybridBayesNet::shared_ptr bn = fg->eliminateSequential();

// Set up sampling
std::mt19937_64 gen(11);
Expand Down
15 changes: 3 additions & 12 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ using symbol_shorthand::X;
TEST(HybridBayesTree, OptimizeMultifrontal) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree::shared_ptr hybridBayesTree =
s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering);
s.linearizedFactorGraph.eliminateMultifrontal();
HybridValues delta = hybridBayesTree->optimize();

VectorValues expectedValues;
Expand Down Expand Up @@ -203,16 +202,8 @@ TEST(HybridBayesTree, Choose) {

GaussianBayesTree gbt = isam.choose(assignment);

Ordering ordering;
ordering += X(0);
ordering += X(1);
ordering += X(2);
ordering += X(3);
ordering += M(0);
ordering += M(1);
ordering += M(2);

// TODO(Varun) get segfault if ordering not provided
// Specify ordering so it matches that of HybridGaussianISAM.
Ordering ordering(KeyVector{X(0), X(1), X(2), X(3), M(0), M(1), M(2)});
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);

auto expected_gbt = bayesTree->choose(assignment);
Expand Down
5 changes: 2 additions & 3 deletions gtsam/hybrid/tests/testHybridEstimation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ TEST(HybridEstimation, Full) {
}

HybridBayesNet::shared_ptr bayesNet =
graph.eliminateSequential(hybridOrdering);
graph.eliminateSequential();

EXPECT_LONGS_EQUAL(2 * K - 1, bayesNet->size());

Expand Down Expand Up @@ -481,8 +481,7 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
const auto fg = createHybridGaussianFactorGraph();

// 2. Eliminate into BN
const Ordering ordering = fg->getHybridOrdering();
const HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
const HybridBayesNet::shared_ptr bn = fg->eliminateSequential();

// Set up sampling
std::mt19937_64 rng(11);
Expand Down
38 changes: 10 additions & 28 deletions gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {

hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));

auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)}));
auto result = hfg.eliminateSequential();

auto dc = result->at(2)->asDiscrete();
DiscreteValues dv;
Expand Down Expand Up @@ -161,8 +160,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialSimple) {
// Joint discrete probability table for c1, c2
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));

HybridBayesNet::shared_ptr result = hfg.eliminateSequential(
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)}));
HybridBayesNet::shared_ptr result = hfg.eliminateSequential();

// There are 4 variables (2 continuous + 2 discrete) in the bayes net.
EXPECT_LONGS_EQUAL(4, result->size());
Expand All @@ -187,8 +185,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
// variable throws segfault
// hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));

HybridBayesTree::shared_ptr result =
hfg.eliminateMultifrontal(hfg.getHybridOrdering());
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal();

// The bayes tree should have 3 cliques
EXPECT_LONGS_EQUAL(3, result->size());
Expand Down Expand Up @@ -218,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));

// Get a constrained ordering keeping c1 last
auto ordering_full = hfg.getHybridOrdering();
auto ordering_full = HybridOrdering(hfg);

// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
Expand Down Expand Up @@ -518,8 +515,7 @@ TEST(HybridGaussianFactorGraph, optimize) {

hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));

auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)}));
auto result = hfg.eliminateSequential();

HybridValues hv = result->optimize();

Expand Down Expand Up @@ -572,9 +568,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {

HybridGaussianFactorGraph graph = s.linearizedFactorGraph;

Ordering hybridOrdering = graph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

const HybridValues delta = hybridBayesNet->optimize();
const double error = graph.error(delta);
Expand All @@ -593,9 +587,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {

HybridGaussianFactorGraph graph = s.linearizedFactorGraph;

Ordering hybridOrdering = graph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.error(delta.continuous());
Expand Down Expand Up @@ -684,10 +676,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
expectedBayesNet.emplace_back(new DiscreteConditional(mode, "74/26"));

// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
const auto posterior = fg.eliminateSequential(ordering);
const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
}

Expand Down Expand Up @@ -719,10 +708,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
expectedBayesNet.emplace_back(new DiscreteConditional(mode, "23/77"));

// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
const auto posterior = fg.eliminateSequential(ordering);
const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
}

Expand All @@ -741,11 +727,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
EXPECT_LONGS_EQUAL(5, fg.size());

// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
ordering.push_back(M(1));
const auto posterior = fg.eliminateSequential(ordering);
const auto posterior = fg.eliminateSequential();

// Compute the log-ratio between the Bayes net and the factor graph.
auto compute_ratio = [&](HybridValues *sample) -> double {
Expand Down
7 changes: 2 additions & 5 deletions gtsam/hybrid/tests/testSerializationHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ TEST(HybridSerialization, GaussianMixture) {
// Test HybridBayesNet serialization.
TEST(HybridSerialization, HybridBayesNet) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential());

EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
Expand All @@ -162,9 +161,7 @@ TEST(HybridSerialization, HybridBayesNet) {
// Test HybridBayesTree serialization.
TEST(HybridSerialization, HybridBayesTree) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
HybridBayesTree hbt = *(s.linearizedFactorGraph.eliminateMultifrontal());

EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
Expand Down
18 changes: 16 additions & 2 deletions gtsam/inference/EliminateableFactorGraph-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,16 @@ namespace gtsam {
if (orderingType == Ordering::METIS) {
Ordering computedOrdering = Ordering::Metis(asDerived());
return eliminateSequential(computedOrdering, function, variableIndex);
} else {
} else if (orderingType == Ordering::COLAMD) {
Ordering computedOrdering = Ordering::Colamd(*variableIndex);
return eliminateSequential(computedOrdering, function, variableIndex);
} else if (orderingType == Ordering::NATURAL) {
Ordering computedOrdering = Ordering::Natural(asDerived());
return eliminateSequential(computedOrdering, function, variableIndex);
} else {
Ordering computedOrdering = EliminationTraitsType::DefaultOrderingFunc(
asDerived(), variableIndex);
return eliminateSequential(computedOrdering, function, variableIndex);
}
}
}
Expand Down Expand Up @@ -100,9 +107,16 @@ namespace gtsam {
if (orderingType == Ordering::METIS) {
Ordering computedOrdering = Ordering::Metis(asDerived());
return eliminateMultifrontal(computedOrdering, function, variableIndex);
} else {
} else if (orderingType == Ordering::COLAMD) {
Ordering computedOrdering = Ordering::Colamd(*variableIndex);
return eliminateMultifrontal(computedOrdering, function, variableIndex);
} else if (orderingType == Ordering::NATURAL) {
Ordering computedOrdering = Ordering::Natural(asDerived());
return eliminateMultifrontal(computedOrdering, function, variableIndex);
} else {
Ordering computedOrdering = EliminationTraitsType::DefaultOrderingFunc(
asDerived(), variableIndex);
return eliminateMultifrontal(computedOrdering, function, variableIndex);
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions gtsam/linear/GaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ namespace gtsam {
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminatePreferCholesky(factors, keys); }
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
boost::optional<const VariableIndex&> variableIndex) {
return Ordering::Colamd(*variableIndex);
}
};

/* ************************************************************************* */
Expand Down
Loading

0 comments on commit 690a230

Please sign in to comment.