Skip to content

Commit

Permalink
Merge pull request #1372 from borglab/hybrid/simplifiedAPI
Browse files Browse the repository at this point in the history
Simplified AP for HybridBayesNet
  • Loading branch information
dellaert authored Jan 6, 2023
2 parents c24e975 + d49bcce commit a3b177c
Show file tree
Hide file tree
Showing 14 changed files with 119 additions and 171 deletions.
12 changes: 10 additions & 2 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,21 @@ const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
return conditionals_;
}

/* *******************************************************************************/
GaussianMixture::GaussianMixture(
KeyVector &&continuousFrontals, KeyVector &&continuousParents,
DiscreteKeys &&discreteParents,
std::vector<GaussianConditional::shared_ptr> &&conditionals)
: GaussianMixture(continuousFrontals, continuousParents, discreteParents,
Conditionals(discreteParents, conditionals)) {}

/* *******************************************************************************/
GaussianMixture::GaussianMixture(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList)
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: GaussianMixture(continuousFrontals, continuousParents, discreteParents,
Conditionals(discreteParents, conditionalsList)) {}
Conditionals(discreteParents, conditionals)) {}

/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixture::add(
Expand Down
12 changes: 12 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ class GTSAM_EXPORT GaussianMixture
const DiscreteKeys &discreteParents,
const Conditionals &conditionals);

/**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
GaussianMixture(KeyVector &&continuousFrontals, KeyVector &&continuousParents,
DiscreteKeys &&discreteParents,
std::vector<GaussianConditional::shared_ptr> &&conditionals);

/**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
*
Expand Down
18 changes: 1 addition & 17 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
prunedGaussianMixture->prune(*decisionTree); // imperative :-(

// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(prunedGaussianMixture));
prunedBayesNetFragment.push_back(prunedGaussianMixture);

} else {
// Add the non-GaussianMixture conditional
Expand All @@ -209,21 +208,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
return prunedBayesNetFragment;
}

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return at(i)->asMixture();
}

/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return at(i)->asGaussian();
}

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return at(i)->asDiscrete();
}

/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
Expand Down
55 changes: 13 additions & 42 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,55 +63,26 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @{

/// Add HybridConditional to Bayes Net
using Base::add;
using Base::emplace_shared;

/// Add a Gaussian Mixture to the Bayes Net.
void addMixture(const GaussianMixture::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
/// Add a conditional directly using a pointer.
template <class Conditional>
void emplace_back(Conditional *conditional) {
factors_.push_back(boost::make_shared<HybridConditional>(
boost::shared_ptr<Conditional>(conditional)));
}

/// Add a Gaussian conditional to the Bayes Net.
void addGaussian(const GaussianConditional::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
/// Add a conditional directly using a shared_ptr.
void push_back(boost::shared_ptr<HybridConditional> conditional) {
factors_.push_back(conditional);
}

/// Add a discrete conditional to the Bayes Net.
void addDiscrete(const DiscreteConditional::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
/// Add a conditional directly using implicit conversion.
void push_back(HybridConditional &&conditional) {
factors_.push_back(
boost::make_shared<HybridConditional>(std::move(conditional)));
}

/// Add a Gaussian Mixture to the Bayes Net.
template <typename... T>
void emplaceMixture(T &&...args) {
push_back(HybridConditional(
boost::make_shared<GaussianMixture>(std::forward<T>(args)...)));
}

/// Add a Gaussian conditional to the Bayes Net.
template <typename... T>
void emplaceGaussian(T &&...args) {
push_back(HybridConditional(
boost::make_shared<GaussianConditional>(std::forward<T>(args)...)));
}

/// Add a discrete conditional to the Bayes Net.
template <typename... T>
void emplaceDiscrete(T &&...args) {
push_back(HybridConditional(
boost::make_shared<DiscreteConditional>(std::forward<T>(args)...)));
}

using Base::push_back;

/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atMixture(size_t i) const;

/// Get a specific Gaussian conditional by index `i`.
GaussianConditional::shared_ptr atGaussian(size_t i) const;

/// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;

/**
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
* value assignment.
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,

/* ************************************************************************ */
HybridConditional::HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional)
const boost::shared_ptr<GaussianConditional> &continuousConditional)
: HybridConditional(continuousConditional->keys(), {},
continuousConditional->nrFrontals()) {
inner_ = continuousConditional;
}

/* ************************************************************************ */
HybridConditional::HybridConditional(
boost::shared_ptr<DiscreteConditional> discreteConditional)
const boost::shared_ptr<DiscreteConditional> &discreteConditional)
: HybridConditional({}, discreteConditional->discreteKeys(),
discreteConditional->nrFrontals()) {
inner_ = discreteConditional;
}

/* ************************************************************************ */
HybridConditional::HybridConditional(
boost::shared_ptr<GaussianMixture> gaussianMixture)
const boost::shared_ptr<GaussianMixture> &gaussianMixture)
: BaseFactor(KeyVector(gaussianMixture->keys().begin(),
gaussianMixture->keys().begin() +
gaussianMixture->nrContinuous()),
Expand Down
7 changes: 4 additions & 3 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,24 @@ class GTSAM_EXPORT HybridConditional
* HybridConditional.
*/
HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional);
const boost::shared_ptr<GaussianConditional>& continuousConditional);

/**
* @brief Construct a new Hybrid Conditional object
*
* @param discreteConditional Conditional used to create the
* HybridConditional.
*/
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
HybridConditional(
const boost::shared_ptr<DiscreteConditional>& discreteConditional);

/**
* @brief Construct a new Hybrid Conditional object
*
* @param gaussianMixture Gaussian Mixture Conditional used to create the
* HybridConditional.
*/
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
HybridConditional(const boost::shared_ptr<GaussianMixture>& gaussianMixture);

/// @}
/// @name Testable
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridSmoother.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
}

// Add the partial bayes net to the posterior bayes net.
hybridBayesNet_.push_back<HybridBayesNet>(*bayesNetFragment);
hybridBayesNet_.add(*bayesNetFragment);
}

/* ************************************************************************* */
Expand Down Expand Up @@ -100,7 +100,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
/* ************************************************************************* */
GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
size_t index) const {
return hybridBayesNet_.atMixture(index);
return hybridBayesNet_.at(index)->asMixture();
}

/* ************************************************************************* */
Expand Down
26 changes: 3 additions & 23 deletions gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -135,29 +135,9 @@ class HybridBayesTree {
#include <gtsam/hybrid/HybridBayesNet.h>
class HybridBayesNet {
HybridBayesNet();
void add(const gtsam::HybridConditional& s);
void addMixture(const gtsam::GaussianMixture* s);
void addGaussian(const gtsam::GaussianConditional* s);
void addDiscrete(const gtsam::DiscreteConditional* s);

void emplaceMixture(const gtsam::GaussianMixture& s);
void emplaceMixture(const gtsam::KeyVector& continuousFrontals,
const gtsam::KeyVector& continuousParents,
const gtsam::DiscreteKeys& discreteParents,
const std::vector<gtsam::GaussianConditional::shared_ptr>&
conditionalsList);
void emplaceGaussian(const gtsam::GaussianConditional& s);
void emplaceDiscrete(const gtsam::DiscreteConditional& s);
void emplaceDiscrete(const gtsam::DiscreteKey& key, string spec);
void emplaceDiscrete(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
void emplaceDiscrete(const gtsam::DiscreteKey& key,
const std::vector<gtsam::DiscreteKey>& parents,
string spec);

gtsam::GaussianMixture* atMixture(size_t i) const;
gtsam::GaussianConditional* atGaussian(size_t i) const;
gtsam::DiscreteConditional* atDiscrete(size_t i) const;
void push_back(const gtsam::GaussianMixture* s);
void push_back(const gtsam::GaussianConditional* s);
void push_back(const gtsam::DiscreteConditional* s);

bool empty() const;
size_t size() const;
Expand Down
16 changes: 8 additions & 8 deletions gtsam/hybrid/tests/TinyHybridExample.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,22 @@ inline HybridBayesNet createHybridBayesNet(int num_measurements = 1,
// Create Gaussian mixture z_i = x0 + noise for each measurement.
for (int i = 0; i < num_measurements; i++) {
const auto mode_i = manyModes ? DiscreteKey{M(i), 2} : mode;
GaussianMixture gm({Z(i)}, {X(0)}, {mode_i},
{GaussianConditional::sharedMeanAndStddev(
Z(i), I_1x1, X(0), Z_1x1, 0.5),
GaussianConditional::sharedMeanAndStddev(
Z(i), I_1x1, X(0), Z_1x1, 3)});
bayesNet.emplaceMixture(gm); // copy :-(
bayesNet.emplace_back(
new GaussianMixture({Z(i)}, {X(0)}, {mode_i},
{GaussianConditional::sharedMeanAndStddev(
Z(i), I_1x1, X(0), Z_1x1, 0.5),
GaussianConditional::sharedMeanAndStddev(
Z(i), I_1x1, X(0), Z_1x1, 3)}));
}

// Create prior on X(0).
bayesNet.addGaussian(
bayesNet.push_back(
GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));

// Add prior on mode.
const size_t nrModes = manyModes ? num_measurements : 1;
for (int i = 0; i < nrModes; i++) {
bayesNet.emplaceDiscrete(DiscreteKey{M(i), 2}, "4/6");
bayesNet.emplace_back(new DiscreteConditional({M(i), 2}, "4/6"));
}
return bayesNet;
}
Expand Down
42 changes: 20 additions & 22 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,29 +42,29 @@ static const DiscreteKey Asia(asiaKey, 2);
// Test creation of a pure discrete Bayes net.
TEST(HybridBayesNet, Creation) {
HybridBayesNet bayesNet;
bayesNet.emplaceDiscrete(Asia, "99/1");
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));

DiscreteConditional expected(Asia, "99/1");
CHECK(bayesNet.atDiscrete(0));
EXPECT(assert_equal(expected, *bayesNet.atDiscrete(0)));
CHECK(bayesNet.at(0)->asDiscrete());
EXPECT(assert_equal(expected, *bayesNet.at(0)->asDiscrete()));
}

/* ****************************************************************************/
// Test adding a Bayes net to another one.
TEST(HybridBayesNet, Add) {
HybridBayesNet bayesNet;
bayesNet.emplaceDiscrete(Asia, "99/1");
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));

HybridBayesNet other;
other.push_back(bayesNet);
other.add(bayesNet);
EXPECT(bayesNet.equals(other));
}

/* ****************************************************************************/
// Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet;
bayesNet.emplaceDiscrete(Asia, "99/1");
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
HybridValues values;
values.insert(asiaKey, 0);
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
Expand All @@ -80,7 +80,7 @@ TEST(HybridBayesNet, Tiny) {
/* ****************************************************************************/
// Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
TEST(HybridBayesNet, evaluateHybrid) {
const auto continuousConditional = GaussianConditional::FromMeanAndStddev(
const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);

const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
Expand All @@ -93,10 +93,11 @@ TEST(HybridBayesNet, evaluateHybrid) {

// Create hybrid Bayes net.
HybridBayesNet bayesNet;
bayesNet.emplaceGaussian(continuousConditional);
GaussianMixture gm({X(1)}, {}, {Asia}, {conditional0, conditional1});
bayesNet.emplaceMixture(gm); // copy :-(
bayesNet.emplaceDiscrete(Asia, "99/1");
bayesNet.push_back(GaussianConditional::sharedMeanAndStddev(
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0));
bayesNet.emplace_back(
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1}));
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));

// Create values at which to evaluate.
HybridValues values;
Expand All @@ -105,7 +106,7 @@ TEST(HybridBayesNet, evaluateHybrid) {
values.insert(X(1), Vector1(1));

const double conditionalProbability =
continuousConditional.evaluate(values.continuous());
continuousConditional->evaluate(values.continuous());
const double mixtureProbability = conditional0->evaluate(values.continuous());
EXPECT_DOUBLES_EQUAL(conditionalProbability * mixtureProbability * 0.99,
bayesNet.evaluate(values), 1e-9);
Expand Down Expand Up @@ -135,17 +136,13 @@ TEST(HybridBayesNet, Choose) {

EXPECT_LONGS_EQUAL(4, gbn.size());

EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atMixture(0)))(assignment),
EXPECT(assert_equal(*(*hybridBayesNet->at(0)->asMixture())(assignment),
*gbn.at(0)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atMixture(1)))(assignment),
EXPECT(assert_equal(*(*hybridBayesNet->at(1)->asMixture())(assignment),
*gbn.at(1)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atMixture(2)))(assignment),
EXPECT(assert_equal(*(*hybridBayesNet->at(2)->asMixture())(assignment),
*gbn.at(2)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atMixture(3)))(assignment),
EXPECT(assert_equal(*(*hybridBayesNet->at(3)->asMixture())(assignment),
*gbn.at(3)));
}

Expand Down Expand Up @@ -247,11 +244,12 @@ TEST(HybridBayesNet, Error) {
double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
if (hybridBayesNet->at(idx)->isHybrid()) {
double error = hybridBayesNet->atMixture(idx)->error(
double error = hybridBayesNet->at(idx)->asMixture()->error(
{delta.continuous(), discrete_values});
total_error += error;
} else if (hybridBayesNet->at(idx)->isContinuous()) {
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
double error =
hybridBayesNet->at(idx)->asGaussian()->error(delta.continuous());
total_error += error;
}
}
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridEstimation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ TEST(HybridEstimation, Probability) {
for (auto discrete_conditional : *discreteBayesNet) {
bayesNet->add(discrete_conditional);
}
auto discreteConditional = discreteBayesNet->atDiscrete(0);
auto discreteConditional = discreteBayesNet->at(0)->asDiscrete();

HybridValues hybrid_values = bayesNet->optimize();

Expand Down
Loading

0 comments on commit a3b177c

Please sign in to comment.