Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add various tests for Hybrid #1220

Merged
merged 5 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@
#include <string>
#include <vector>

using boost::assign::operator+=;

namespace gtsam {

using boost::assign::operator+=;

/****************************************************************************/
// Node
/****************************************************************************/
Expand Down
41 changes: 30 additions & 11 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
*/

#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>
Expand All @@ -36,8 +36,7 @@ GaussianMixture::GaussianMixture(
conditionals_(conditionals) {}

/* *******************************************************************************/
const GaussianMixture::Conditionals &
GaussianMixture::conditionals() {
const GaussianMixture::Conditionals &GaussianMixture::conditionals() {
return conditionals_;
}

Expand All @@ -48,8 +47,8 @@ GaussianMixture GaussianMixture::FromConditionals(
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
Conditionals dt(discreteParents, conditionalsList);

return GaussianMixture(continuousFrontals, continuousParents,
discreteParents, dt);
return GaussianMixture(continuousFrontals, continuousParents, discreteParents,
dt);
}

/* *******************************************************************************/
Expand All @@ -66,8 +65,7 @@ GaussianMixture::Sum GaussianMixture::add(
}

/* *******************************************************************************/
GaussianMixture::Sum
GaussianMixture::asGaussianFactorGraphTree() const {
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraph result;
result.push_back(factor);
Expand All @@ -77,21 +75,42 @@ GaussianMixture::asGaussianFactorGraphTree() const {
}

/* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf,
double tol) const {
size_t GaussianMixture::nrComponents() const {
size_t total = 0;
conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
if (node) total += 1;
});
return total;
}

/* *******************************************************************************/
GaussianConditional::shared_ptr GaussianMixture::operator()(
const DiscreteValues &discreteVals) const {
auto &ptr = conditionals_(discreteVals);
if (!ptr) return nullptr;
auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional)
return conditional;
else
throw std::logic_error(
"A GaussianMixture unexpectedly contained a non-conditional");
}

/* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol);
}

/* *******************************************************************************/
void GaussianMixture::print(const std::string &s,
const KeyFormatter &formatter) const {
const KeyFormatter &formatter) const {
std::cout << s;
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
BaseConditional::print("", formatter);
std::cout << "\nDiscrete Keys = ";
std::cout << " Discrete Keys = ";
for (auto &dk : discreteKeys()) {
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
Expand Down
12 changes: 12 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

#pragma once

#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h>
Expand Down Expand Up @@ -99,6 +101,16 @@ class GTSAM_EXPORT GaussianMixture
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/// @}
/// @name Standard API
/// @{

GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteVals) const;

/// Returns the total number of continuous components
size_t nrComponents() const;

/// @}
/// @name Testable
/// @{
Expand Down
7 changes: 5 additions & 2 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,19 @@ GaussianMixtureFactor GaussianMixtureFactor::FromFactors(
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
std::cout << "]{\n";
factors_.print(
"mixture = ", [&](Key k) { return formatter(k); },
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty())
std::cout << ":\n";
if (gf)
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
});
std::cout << "}" << std::endl;
}

/* *******************************************************************************/
Expand Down
19 changes: 19 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,19 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
const DiscreteKeys &discreteKeys,
const Factors &factors);

/**
* @brief Construct a new GaussianMixtureFactor object using a vector of
* GaussianFactor shared pointers.
*
* @param keys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys.
* @param factors Vector of gaussian factor shared pointers.
*/
GaussianMixtureFactor(const KeyVector &keys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors)
: GaussianMixtureFactor(keys, discreteKeys,
Factors(discreteKeys, factors)) {}

static This FromFactors(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors);
Expand Down Expand Up @@ -111,6 +124,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @return Sum
*/
Sum add(const Sum &sum) const;

/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
return sum;
}
};

// traits
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridDiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
void HybridDiscreteFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
inner_->print("inner: ", formatter);
inner_->print("\n", formatter);
};

} // namespace gtsam
23 changes: 19 additions & 4 deletions gtsam/hybrid/HybridFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,

/* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &keys)
: Base(keys), isContinuous_(true), nrContinuous_(keys.size()) {}
: Base(keys),
isContinuous_(true),
nrContinuous_(keys.size()),
continuousKeys_(keys) {}

/* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
Expand All @@ -60,13 +63,15 @@ HybridFactor::HybridFactor(const KeyVector &continuousKeys,
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
nrContinuous_(continuousKeys.size()),
discreteKeys_(discreteKeys) {}
discreteKeys_(discreteKeys),
continuousKeys_(continuousKeys) {}

/* ************************************************************************ */
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
: Base(CollectKeys({}, discreteKeys)),
isDiscrete_(true),
discreteKeys_(discreteKeys) {}
discreteKeys_(discreteKeys),
continuousKeys_({}) {}

/* ************************************************************************ */
bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
Expand All @@ -83,7 +88,17 @@ void HybridFactor::print(const std::string &s,
if (isContinuous_) std::cout << "Continuous ";
if (isDiscrete_) std::cout << "Discrete ";
if (isHybrid_) std::cout << "Hybrid ";
this->printKeys("", formatter);
for (size_t c=0; c<continuousKeys_.size(); c++) {
std::cout << formatter(continuousKeys_.at(c));
if (c < continuousKeys_.size() - 1) {
std::cout << " ";
} else {
std::cout << "; ";
}
}
for(auto && discreteKey: discreteKeys_) {
std::cout << formatter(discreteKey.first) << " ";
}
}

} // namespace gtsam
3 changes: 3 additions & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class GTSAM_EXPORT HybridFactor : public Factor {
protected:
DiscreteKeys discreteKeys_;

/// Record continuous keys for book-keeping
KeyVector continuousKeys_;

public:
// typedefs needed to play nice with gtsam
typedef HybridFactor This; ///< This class
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
inner_->print("inner: ", formatter);
inner_->print("\n", formatter);
};

} // namespace gtsam
92 changes: 92 additions & 0 deletions gtsam/hybrid/tests/testGaussianMixture.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/* ----------------------------------------------------------------------------

* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)

* See LICENSE for the license information

* -------------------------------------------------------------------------- */

/**
* @file testGaussianMixture.cpp
* @brief Unit tests for GaussianMixture class
* @author Varun Agrawal
* @author Fan Jiang
* @author Frank Dellaert
* @date December 2021
*/

#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>

#include <vector>

// Include for test suite
#include <CppUnitLite/TestHarness.h>

using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;

/* ************************************************************************* */
TEST(GaussianConditional, Equals) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

document the equations here (2 gaussians)

// create a conditional gaussian node
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 2;
S1(0, 1) = 3;
S1(1, 1) = 4;

Matrix S2(2, 2);
S2(0, 0) = 6;
S2(1, 0) = 0.2;
S2(0, 1) = 8;
S2(1, 1) = 0.4;

Matrix R1(2, 2);
R1(0, 0) = 0.1;
R1(1, 0) = 0.3;
R1(0, 1) = 0.0;
R1(1, 1) = 0.34;

Matrix R2(2, 2);
R2(0, 0) = 0.1;
R2(1, 0) = 0.3;
R2(0, 1) = 0.0;
R2(1, 1) = 0.34;

SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));

Vector2 d1(0.2, 0.5), d2(0.5, 0.2);

auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
X(2), S1, model),
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
X(2), S2, model);

// Create decision tree
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals);

// Let's check that this worked:
DiscreteValues mode;
mode[m1.first] = 1;
auto actual = mixtureFactor(mode);
EXPECT(actual == conditional1);
}

/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

newline

Loading