Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HybridGaussianFactorGraph error #1358

Merged
merged 18 commits into from
Dec 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>

Expand Down Expand Up @@ -149,17 +150,19 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->likelihood(frontals);
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(frontals),
conditional->logNormalizationConstant()};
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}

/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
s.insert(discreteKeys.begin(), discreteKeys.end());
return s;
}

Expand All @@ -184,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
const DiscreteValues values(choices);

// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
Expand Down Expand Up @@ -254,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
}

/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousValues);
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous());
}

} // namespace gtsam
12 changes: 6 additions & 6 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
namespace gtsam {

class GaussianMixtureFactor;
class HybridValues;

/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
Expand Down Expand Up @@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Constructors
/// @{

/// Defaut constructor, mainly for serialization.
/// Default constructor, mainly for serialization.
GaussianMixture() = default;

/**
Expand Down Expand Up @@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Standard API
/// @{

/// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;

Expand Down Expand Up @@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const override;

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
Expand All @@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
};

/// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);

// traits
template <>
Expand Down
49 changes: 27 additions & 22 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,20 @@
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

/* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}
const Mixture &factors)
: Base(continuousKeys, discreteKeys),
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
return FactorAndConstant{gf, 0.0};
}) {}

/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
Expand All @@ -43,11 +48,11 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {

// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const GaussianFactor::shared_ptr &f1,
const GaussianFactor::shared_ptr &f2) {
return f1->equals(*f2, tol);
});
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
const FactorAndConstant &f2) {
return f1.factor->equals(*(f2.factor), tol) &&
std::abs(f1.constant - f2.constant) < tol;
});
}

/* *******************************************************************************/
Expand All @@ -60,7 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
[&](const FactorAndConstant &gf_z) -> std::string {
auto gf = gf_z.factor;
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
Expand All @@ -75,8 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s,
}

/* *******************************************************************************/
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
return factors_;
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
return Mixture(factors_, [](const FactorAndConstant &factor_z) {
return factor_z.factor;
});
}

/* *******************************************************************************/
Expand All @@ -95,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result;
result.push_back(factor);
result.push_back(factor_z.factor);
return result;
};
return {factors_, wrap};
Expand All @@ -107,21 +115,18 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const GaussianFactor::shared_ptr &factor) {
return factor->error(continuousValues);
};
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
return factor_z.error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixtureFactor::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto factor = factors_(discreteValues);
return factor->error(continuousValues);
double GaussianMixtureFactor::error(const HybridValues &values) const {
const FactorAndConstant factor_z = factors_(values.discrete());
return factor_z.error(values.continuous());
}
/* *******************************************************************************/

} // namespace gtsam
63 changes: 42 additions & 21 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,15 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>

namespace gtsam {

class GaussianFactorGraph;

// Needed for wrapper.
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
class HybridValues;
class DiscreteValues;
class VectorValues;

/**
* @brief Implementation of a discrete conditional mixture factor.
Expand All @@ -53,9 +51,29 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using shared_ptr = boost::shared_ptr<This>;

using Sum = DecisionTree<Key, GaussianFactorGraph>;

/// typedef for Decision Tree of Gaussian Factors
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
using sharedFactor = boost::shared_ptr<GaussianFactor>;

/// Gaussian factor and log of normalizing constant.
struct FactorAndConstant {
sharedFactor factor;
double constant;

// Return error with constant correction.
double error(const VectorValues &values) const {
// Note minus sign: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood, hence we subtract the constant here.
return factor->error(values) - constant;
}

// Check pointer equality.
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}
};

/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;

private:
/// Decision tree of Gaussian factors indexed by discrete keys.
Expand All @@ -82,12 +100,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and
* their cardinalities.
* @param factors The decision tree of Gaussian Factors stored as the mixture
* @param factors The decision tree of Gaussian factors stored as the mixture
* density.
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors);
const Mixture &factors);

GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors_and_z)
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}

/**
* @brief Construct a new GaussianMixtureFactor object using a vector of
Expand All @@ -99,9 +122,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors)
const std::vector<sharedFactor> &factors)
: GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {}
Mixture(discreteKeys, factors)) {}

/// @}
/// @name Testable
Expand All @@ -113,9 +136,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard API
/// @{

/// Getter for the underlying Gaussian Factor Decision Tree.
const Factors &factors();
const Mixture factors() const;

/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
Expand All @@ -137,21 +162,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @brief Compute the log-likelihood, including the log-normalizing constant.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const override;

/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
return sum;
}
/// @}
};

// traits
Expand Down
10 changes: 5 additions & 5 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
Expand All @@ -12,6 +12,7 @@
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @author Frank Dellaert
* @date January 2022
*/

Expand Down Expand Up @@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
}

/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = choose(discreteValues);
return gbn.error(continuousValues);
double HybridBayesNet::error(const HybridValues &values) const {
GaussianBayesNet gbn = choose(values.discrete());
return gbn.error(values.continuous());
}

/* ************************************************************************* */
Expand Down
6 changes: 2 additions & 4 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const;

/**
* @brief Compute conditional error for each discrete assignment,
Expand Down
Loading