Skip to content

Commit

Permalink
Merge pull request #1318 from borglab/hybrid/error
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Dec 23, 2022
2 parents 7f19e3f + 23ec7ed commit 07d0a03
Show file tree
Hide file tree
Showing 16 changed files with 594 additions and 39 deletions.
8 changes: 4 additions & 4 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace gtsam {
static inline double id(const double& x) { return x; }
};

AlgebraicDecisionTree() : Base(1.0) {}
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}

// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}
Expand Down Expand Up @@ -158,9 +158,9 @@ namespace gtsam {
}

/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
void print(const std::string& s = "",
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.8g") % v).str();
};
Expand Down
30 changes: 28 additions & 2 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ size_t GaussianMixture::nrComponents() const {

/* *******************************************************************************/
GaussianConditional::shared_ptr GaussianMixture::operator()(
const DiscreteValues &discreteVals) const {
auto &ptr = conditionals_(discreteVals);
const DiscreteValues &discreteValues) const {
auto &ptr = conditionals_(discreteValues);
if (!ptr) return nullptr;
auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional)
Expand Down Expand Up @@ -207,4 +207,30 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
conditionals_.root_ = pruned_conditionals.root_;
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
// functor to calculate to double error value from GaussianConditional.
auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->error(continuousValues);
} else {
// Return arbitrarily large error if conditional is null
// Conditional is null if it is pruned out.
return 1e50;
}
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousValues);
}

} // namespace gtsam
22 changes: 21 additions & 1 deletion gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture
/// @{

GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteVals) const;
const DiscreteValues &discreteValues) const;

/// Returns the total number of continuous components
size_t nrComponents() const;
Expand All @@ -144,6 +144,26 @@ class GTSAM_EXPORT GaussianMixture
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();

/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
Expand Down
22 changes: 22 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,26 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
};
return {factors_, wrap};
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const GaussianFactor::shared_ptr &factor) {
return factor->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixtureFactor::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto factor = factors_(discreteValues);
return factor->error(continuousValues);
}

} // namespace gtsam
24 changes: 24 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@

#pragma once

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>

namespace gtsam {

class GaussianFactorGraph;

// Needed for wrapper.
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;

/**
Expand Down Expand Up @@ -126,6 +130,26 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
Sum add(const Sum &sum) const;

/**
* @brief Compute error of the GaussianMixtureFactor as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;

/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
Expand Down
52 changes: 52 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,56 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize();
}

/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = this->choose(discreteValues);
return gbn.error(continuousValues);
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree;

// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> conditional_error;

if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);

// Assign for the first index, add error for subsequent ones.
if (idx == 0) {
error_tree = conditional_error;
} else {
error_tree = error_tree + conditional_error;
}

} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}

return error_tree;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
return error_tree.apply([](double error) { return exp(-error); });
}

} // namespace gtsam
33 changes: 33 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves);

/**
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;

/**
* @brief Compute conditional error for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute unnormalized probability q(μ|M),
* for each discrete assignment, and return as a tree.
* q(μ|M) is the unnormalized probability at the MLE point μ,
* conditioned on the discrete variables.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues &continuousValues) const;

/// @}

private:
Expand Down
54 changes: 54 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,4 +423,58 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
return ordering;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);

// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> factor_error;

if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
// Compute factor error.
factor_error = gaussianMixture->error(continuousValues);

// If first factor, assign error, else add it.
if (idx == 0) {
error_tree = factor_error;
} else {
error_tree = error_tree + factor_error;
}

} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
auto hybridGaussianFactor =
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();

// Compute the error of the gaussian factor.
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}

return error_tree;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree =
error_tree.apply([](double error) { return exp(-error); });
return prob_tree;
}

} // namespace gtsam
27 changes: 25 additions & 2 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class JacobianFactor;

/**
* @brief Main elimination function for HybridGaussianFactorGraph.
*
*
* @param factors The factor graph to eliminate.
* @param keys The elimination ordering.
* @return The conditional on the ordering keys and the remaining factors.
Expand Down Expand Up @@ -99,11 +99,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This

using Values = gtsam::Values; ///< backwards compatibility
using Indices = KeyVector; ///> map from keys to values
using Indices = KeyVector; ///< map from keys to values

/// @name Constructors
/// @{

/// @brief Default constructor.
HybridGaussianFactorGraph() = default;

/**
Expand Down Expand Up @@ -170,6 +171,28 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
}
}

/**
* @brief Compute error for each discrete assignment,
* and return as a tree.
*
* Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;

/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
Expand Down
Loading

0 comments on commit 07d0a03

Please sign in to comment.